I am trying to add a callback to my ODE solve while using Zygote to differentiate with respect to the model parameters. The callback's action is very simple, not very different from the one used in the Event Handling and Callback Functions tutorial. I define the dynamics with an out-of-place function of the form `f(u,p,t)` (not in-place), and this is my callback setup:
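```julia
# Out-of-place problem: iceflow_UDE_batch(u, p, t) returns the derivative
iceflow_prob = ODEProblem(iceflow_UDE_batch, H, tspan, θ; tstops=tstops)

# Fire the callback whenever the integrator stops at one of the tstops
function stop_condition(u, t, integrator)
    t in tstops
end

# Scale the state by 1.001 (this is the line that triggers the errors below)
function action!(integrator)
    integrator.u *= 1.001
end

cb_MB = DiscreteCallback(stop_condition, action!)

iceflow_sol = solve(iceflow_prob,
                    ROCK4(),
                    callback=cb_MB,
                    tstops=tstops,
                    u0=H,
                    p=θ,
                    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),
                    reltol=10f-6,
                    save_everystep=true,
                    progress=true,
                    progress_steps=100)
```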
When I run this with `integrator.u *= 1.001`, I receive the same error that was reported in this issue in DiffEqFlux:

```
setfield!: immutable struct of type FakeIntegrator cannot be changed
```
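For context, `x.u *= c` is shorthand for `x.u = x.u * c`, so it tries to rebind a field of the integrator object, which fails whenever that object is an immutable struct, as the `FakeIntegrator` in the error apparently is. The sketch below uses a hypothetical `ToyIntegrator` struct (not the SciMLSensitivity type) to reproduce the same kind of error, and shows that broadcasting in place (`.*=`) leaves the struct itself untouched:

```julia
# Hypothetical stand-in for an immutable integrator type;
# this is NOT the actual FakeIntegrator from SciMLSensitivity.
struct ToyIntegrator
    u::Vector{Float64}
end

integ = ToyIntegrator([1.0, 2.0])

# `integ.u *= 1.001` lowers to `integ.u = integ.u * 1.001`, i.e. a
# setproperty!/setfield! call on an immutable struct, which throws.
rebind_error = try
    integ.u *= 1.001
    nothing
catch err
    err
end
println(sprint(showerror, rebind_error))

# `integ.u .*= 1.001` instead writes into the array stored in the field;
# the struct is never modified, so this succeeds.
integ.u .*= 1.001
println(integ.u)
```

This is why the broadcasted form gets past the first error: it changes the array's contents rather than the field binding.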
Following @ChrisRackauckas's suggestion in the same issue, I changed the line to `integrator.u .*= 1.001` instead, but then I received a different error message:

```
Mutating arrays is not supported -- called copyto!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations
```
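The second message is Zygote's general restriction on array mutation rather than something specific to callbacks: `y .*= c` lowers to an in-place `copyto!`, which Zygote refuses to differentiate. A small sketch outside of any ODE solve, using two hypothetical functions `scale_inplace` and `scale_pure`, shows the difference:

```julia
using Zygote

# In-place version: `y .*= 1.001` becomes a copyto! into y, which Zygote
# rejects with "Mutating arrays is not supported".
function scale_inplace(x)
    y = copy(x)
    y .*= 1.001
    return sum(y)
end

# Non-mutating version: allocates a new array, which Zygote can differentiate.
scale_pure(x) = sum(x .* 1.001)

x = [1.0, 2.0]

inplace_failed = try
    Zygote.gradient(scale_inplace, x)
    false
catch err
    println(sprint(showerror, err))
    true
end

grad = Zygote.gradient(scale_pure, x)[1]
println(grad)
```

Inside a callback this non-mutating rewrite is not directly available, since `affect!` is expected to modify the integrator, which is presumably why the suggestions in the Zygote documentation are hard to apply here.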
The Zygote documentation gives some advice on how to fix this, but I cannot make it work here. What I also don't understand is why the same workflow works perfectly fine in this simpler toy example (also based on the tutorial):
```julia
using DifferentialEquations
using Zygote, SciMLSensitivity
using DiffEqFlux

# Out-of-place exponential decay
function dynamics(u, p, t)
    return - p .* u
end

# Add 10 to the state when the integrator reaches t = 4
condition(u, t, integrator) = t == 4
effect!(integrator) = integrator.u[1] += 10
cb = DiscreteCallback(condition, effect!)

t₁ = 10.0
u0 = [10.0]
p₁ = 0.1
p = [p₁]
prob = ODEProblem(dynamics, u0, (0.0, t₁), p; tstops=[4.0])
sol = solve(prob, Tsit5(), callback=cb, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))
```
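One detail that may matter for the comparison: as written, the toy snippet only runs the forward `solve`, and the `sensealg` is only used when that solve is differentiated. The sketch below shows one way to take a Zygote gradient through the toy problem so both cases are compared under the same conditions. The loss `toy_loss` (sum of the final state) is a hypothetical choice, `OrdinaryDiffEq` is used in place of the full `DifferentialEquations` metapackage, and the reverse pass is wrapped in `try` because its outcome is exactly what is being compared:

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

dynamics(u, p, t) = -p .* u
condition(u, t, integrator) = t == 4
effect!(integrator) = integrator.u[1] += 10
cb = DiscreteCallback(condition, effect!)

prob = ODEProblem(dynamics, [10.0], (0.0, 10.0), [0.1]; tstops=[4.0])

# Hypothetical scalar loss: sum of the state at the final time
function toy_loss(p)
    sol = solve(prob, Tsit5(); p=p, callback=cb,
                sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))
    return sum(sol.u[end])
end

# Forward solve only (this is all the original snippet does)
println(toy_loss([0.1]))

# Reverse pass through the callback; keep any error for comparison
grad_result = try
    Zygote.gradient(toy_loss, [0.1])
catch err
    err
end
println(grad_result)
```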
I am not sure what the problem is… The code can be found in this fork of the ODINN.jl project (work in collaboration with @JordiBolibar). The gradients are computed perfectly well when we don't use callbacks, but I would like this to work with callbacks too. Sorry I cannot provide a full MWE, but the code I am running is inside `batch_iceflow_UDE` in this script.