iqopi
August 10, 2021, 11:40pm
Hey,
I am trying to optimize the parameters of a differential equation using DiffEqFlux.
Since I am only interested in the new steady state, I am using a callback to terminate the integration once this new state is reached. It looks something like this:
```julia
condition(u, t, integrator) = sum(abs.(u[2*n+1:end])) < 1e-7 * sum(abs.(barX))
affect!(integrator) = terminate!(integrator)
cb = DiscreteCallback(condition, affect!)
sol = solve(prob, BS3(), p = p, callback = cb)
```
However, if I hit this callback while calculating the gradient, I get the following error message:
```
ERROR: MethodError: no method matching terminate!(::DiffEqSensitivity.FakeIntegrator{Vector{Float64}, Vector{Float64}, Float64})
```
So it seems `terminate!` is not yet implemented in the DiffEqSensitivity stack.
Is there an easy way to circumvent this problem?
(Sorry for not providing a complete working example; I hope it's still clear what the problem is.)
Yes, a pure-AD sensealg should be fine: `sol = solve(prob, BS3(), p = p, callback = cb, sensealg = ReverseDiffAdjoint())`. But we should get an issue opened on this anyway. However…
In this case you don't need to differentiate through the ODE solve at all! If you use `SteadyStateProblem`, it has a special adjoint that handles this case nicely.
And actually, I was just reviewing https://github.com/SciML/DiffEqSensitivity.jl/pull/445, which should solve this.
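For reference, the idea behind a steady-state adjoint is roughly this: at a steady state u* the right-hand side vanishes, f(u*, p) = 0, so the implicit function theorem gives du*/dp = -(∂f/∂u)⁻¹ ∂f/∂p evaluated at u*. The gradient then needs one linear solve at the steady state instead of a backward pass through every time step, which is why the termination callback never has to be differentiated. This only works if ∂f/∂u is invertible at u*, i.e. the steady state is isolated. Below is a minimal sketch, assuming the 2021-era package layout (`SteadyStateProblem` and `DynamicSS` available through DifferentialEquations, `SteadyStateAdjoint` from DiffEqSensitivity; newer releases may have moved these names). The toy system `f` and the loss `ssloss` are hypothetical, chosen because u* = p[1]/p[2] makes the gradient easy to check by hand.

```julia
using DifferentialEquations, DiffEqSensitivity, Zygote

# Hypothetical toy system: du/dt = p[1] - p[2]*u, with unique steady state u* = p[1]/p[2].
# Out-of-place form f(u, p, t); t is ignored for steady-state problems.
f(u, p, t) = p[1] .- p[2] .* u

u0 = [0.0]       # initial guess for the steady-state search
p  = [2.0, 4.0]  # gives u* = 0.5

# Hypothetical loss: simply the steady-state value itself.
function ssloss(p)
    ssprob = SteadyStateProblem(f, u0, p)
    # DynamicSS integrates the ODE until du/dt is numerically zero;
    # SteadyStateAdjoint differentiates the result at the steady state
    # instead of backpropagating through the individual time steps.
    sol = solve(ssprob, DynamicSS(Tsit5()); sensealg = SteadyStateAdjoint())
    return sum(sol.u)
end

dp = Zygote.gradient(ssloss, p)[1]
# By hand: d(p[1]/p[2])/dp = [1/p[2], -p[1]/p[2]^2] = [0.25, -0.125],
# so dp should be close to this, up to the steady-state tolerance.
```

No user-written termination callback is needed here, since `DynamicSS` applies its own steady-state stopping criterion.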
iqopi
August 11, 2021, 9:16am
Thanks so much! However, if I do that, I get a different error:
```julia
#%%
using DifferentialEquations, DiffEqFlux, Plots, Flux
#%%
function f!(du, u, p, t)
    du[1] = - p[1]*u[3]
    du[2] = - p[2]*u[3]
    du[3] = - du[1]*u[2] - du[2]*u[1]
end
#%%
u0 = [1., 1., 1.]
tspan = (0.0, 10.0)
#%%
# Stop the integration once u[3] has (almost) decayed to zero
condition(u, t, integrator) = abs(u[3]) < 1e-4
affect!(integrator) = terminate!(integrator)
cb = DiscreteCallback(condition, affect!)
#%%
# Set up the ODE problem, then solve it to get the reference end state
p = [1., .5]
prob = ODEProblem(f!, u0, tspan, p)
sol = solve(prob, BS3(), p = p, callback = cb)
fit_sol = sol.u[end]
#%%
plot(sol)
#%%
function loss(pl)
    sol = solve(prob, BS3(), p = pl) # , callback = cb, sensealg = ReverseDiffAdjoint())
    state = sol.u[end]
    return sum(abs, sol[end][1:2] - fit_sol[1:2])
end
#%%
loss(p)
#########################
# THIS WORKS
#########################
#%%
pv = rand(2)
par = Flux.params(pv)
grads = gradient(() -> loss(pv), par)
grads[pv]
#%%
function loss(pl)
    sol = solve(prob, BS3(), p = pl, callback = cb, sensealg = ReverseDiffAdjoint())
    state = sol.u[end]
    return sum(abs, sol[end][1:2] - fit_sol[1:2])
end
#%%
loss(p)
#####################
# Here I get an error
#####################
#%%
par = Flux.params(pv)
grads = gradient(() -> loss(pv), par)
grads[pv]
```
The error is
```
ERROR: LoadError: TrackedArrays do not support setindex!
Stacktrace:
```
and it points back to the `sol = ...` line in the loss function.