AD in DiffEqFlux using DiscreteCallback

I am trying to add a callback to my ODE solve and use Zygote to differentiate the solution with respect to the model parameters. The callback's action is very simple and not very different from the one used in Event Handling and Callback Functions. I define the dynamics with an out-of-place function of the form f(u,p,t), and this is my problem setup and callback:

iceflow_prob = ODEProblem(iceflow_UDE_batch,
                          H,
                          tspan,
                          θ,
                          tstops=tstops)
 
function stop_condition(u,t,integrator) 
    t in tstops
end
function action!(integrator)
    integrator.u *= 1.001
end

cb_MB = DiscreteCallback(stop_condition, action!)
    
iceflow_sol = solve(iceflow_prob, 
                    ROCK4(), 
                    callback=cb_MB,
                    tstops=tstops,
                    u0=H, 
                    p=θ,
                    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),
                    reltol=10f-6, 
                    save_everystep=true,
                    progress=true, 
                    progress_steps = 100)

When I run this with integrator.u *= 1.001, I receive the same error that was reported in this issue in DiffEqFlux:

setfield!: immutable struct of type FakeIntegrator cannot be changed

Following @ChrisRackauckas's suggestion in the same issue, I changed the line to integrator.u .*= 1.001, but then I received a different error message:

Mutating arrays is not supported -- called copyto!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

The Zygote documentation gives some advice on how to fix this, but I cannot make it work here. What I also don’t understand is why the same workflow works perfectly fine in this simpler toy example (also based on the tutorial):

using DifferentialEquations
using Zygote, SciMLSensitivity
using DiffEqFlux

function dynamics(u, p, t)
     return - p .* u
end

condition(u,t,integrator) = t==4
effect!(integrator) = integrator.u[1] += 10
cb = DiscreteCallback(condition, effect!)

t₁ = 10.0
u0 = [10.0]
p₀ = 0.1
p = [p₀]

prob = ODEProblem(dynamics, u0, (0.0,t₁), p, tstops=[4.0])
sol = solve(prob, Tsit5(), callback=cb, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

I am not sure what the problem is. The code can be found in this fork of the ODINN.jl project (work in collaboration with @JordiBolibar). The gradients are computed correctly when we don’t use callbacks, but I would like this to work with callbacks too. Sorry I cannot provide a full MWE, but the code I am running is in batch_iceflow_UDE in this script.


[Based on @frankschae's response in the Julia Slack]

Currently, only ReverseDiffVJP() is supported for callbacks. After changing sensealg to InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), the previous code started working.
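
For reference, here is a minimal sketch of what that change might look like on the toy example from the question. It is not the ODINN code. The loss function, the saveat grid, and the save_positions=(false,false) setting are my own illustrative choices (assumptions, not something from the original post), and I have not checked this against every SciMLSensitivity version.

```julia
using DifferentialEquations
using SciMLSensitivity, Zygote

# Out-of-place dynamics, same form as in the question
dynamics(u, p, t) = -p .* u

# Callback fires at the tstop t = 4 and bumps the state, as in the toy example
condition(u, t, integrator) = t == 4
effect!(integrator) = integrator.u[1] += 10
# save_positions=(false,false) is an assumption made here so that the saved
# time points come only from saveat
cb = DiscreteCallback(condition, effect!, save_positions=(false, false))

u0 = [10.0]
p = [0.1]
prob = ODEProblem(dynamics, u0, (0.0, 10.0), p)

# Hypothetical loss: sum of the solution at the saved time points
function loss(p)
    sol = solve(prob, Tsit5(), p=p, callback=cb, tstops=[4.0], saveat=1.0,
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))
    return sum(Array(sol))
end

# Gradient of the loss with respect to the parameters
grad, = Zygote.gradient(loss, p)
```

The only change relative to the setup in the question is autojacvec=ReverseDiffVJP() in place of ZygoteVJP(); the callback itself still mutates integrator.u.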
