Making Enzyme.jl work with the integrator interface of OrdinaryDiffEq.jl

I’m new to using autodiff in Julia and have some issues making Enzyme.jl work with OrdinaryDiffEq.jl. I want to take a reverse-mode derivative of my differential equation solution with respect to the input parameter p. This is a simple example, but my actual model also needs the integrator interface of OrdinaryDiffEq, so I am looking for a way to make Enzyme work in reverse mode with the integrator interface. Here is my implementation:

using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity

function fun(u, p, t)
    return -p[1] * t
end

function test_fun(p; kwargs...)
    u0 = 2.0
    tspan = (1.0, 2.0)
    prob = ODEProblem{false}(fun, u0, tspan, p)
    integrator = OrdinaryDiffEq.init(prob, OrdinaryDiffEq.RK4(), save_everystep = false; kwargs...)
    sol = solve!(integrator).u[2]
    return sol

end

p = [1.0]
dp = [0.0]
y = [0.0]
dy = [1.0]
Enzyme.autodiff(Reverse, test_fun, Duplicated(p, dp), Duplicated(y, dy))

I have tried multiple iterations for this, but nothing seems to work. Any inputs would be helpful!

You can definitely do this with Enzyme, but I suggest using DifferentiationInterface.jl with the Enzyme backend. I find that API more intuitive, and you also get access to other backends for free.
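As a rough illustration of that suggestion, here is a minimal sketch of the DifferentiationInterface.jl call pattern with the Enzyme backend. The `loss` function is a hypothetical stand-in, not the ODE solve from the question; whether the same call works through an ODE solve depends on the Julia and package versions discussed in this thread.

```julia
using DifferentiationInterface  # provides `gradient` and re-exports backend types such as AutoEnzyme
import Enzyme                   # must be loaded so the Enzyme backend is available

# Hypothetical toy objective standing in for the real model
loss(p) = sum(abs2, p)

p = [1.0, 2.0]
backend = AutoEnzyme()          # swap for another backend type to try a different AD package

# Gradient of `loss` with respect to `p`
g = gradient(loss, backend, p)
```

Switching AD packages then only means changing the `backend` object, assuming the target package is loaded.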

Works fine on Julia 1.10. This seems to be a 1.11 issue, which is typical with Enzyme. Like you, I did not find a workaround.

@kylebeggs Thanks for the tip, that’ll be useful

@yolhan_mannes Can you let me know the exact version of Julia you used to execute it, the version of Enzyme, OrdinaryDiffEq and SciMLSensitivity?

Julia Version 1.10.8
[7da242da] Enzyme v0.13.30
[1dea7af3] OrdinaryDiffEq v6.91.0
[1ed8b502] SciMLSensitivity v7.74.0

I also changed your code a little while trying to make it work on 1.11 (I did not succeed; see "SciMLSensitivity + Enzyme + 1.11 issue", Issue #2318 in EnzymeAD/Enzyme.jl on GitHub):

using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity

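function fun(du, u, p, t)
    du .= -p[1] * t  # in-place right-hand side
    nothing
end

function test_fun(p, prob)
    prob = remake(prob, p = p)  # differentiate through a remake, not the ODEProblem constructor
    sol = solve(prob, RK4(), save_everystep = false)
    res = sol.u[2]  # state at the final time (sol.u[1] is the initial condition)
    return res[1]
end


p = [1.0]
dp = [0.0]  # receives the derivative with respect to p
u0 = [2.0]
prob = ODEProblem{true}(fun, u0, (1.0, 2.0), p)
dprob = Enzyme.make_zero(prob)  # zeroed shadow of the problem

Enzyme.autodiff(Enzyme.Reverse, test_fun, Duplicated(p, dp), DuplicatedNoNeed(prob, dprob))
@info dp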

The only thing you can’t get around, I think, is making u0 a vector, because SciMLSensitivity needs to go through your du (it must be a Ref). Also, always differentiate a remake rather than the original ODEProblem creation; the latter may work, but it’s really not pretty. You could go back to an explicit definition for “fun”, but since I made u0 a vector, it’s better to avoid that.

@yolhan_mannes Thank you so much! It was indeed a version issue; it works once I switch to 1.10.

@yolhan_mannes Adding to this, do you have an idea how to make it work with the integrator interface? It works with a direct solve, but not when I try to use the integrator interface.

What about using a callback to interface with it? Also, do you have an example of a keyword argument that you pass to init but not to solve?

With Enzyme, that is not fully supported yet. I hope to change that in the next few months, but it will take a bit, and there are certain trade-offs with that.

@yolhan_mannes I’m not exactly sure what you meant, but I need to solve the differential equation one step at a time, perform some calculation, and then move to the next step. I could set save_everystep = true, store the solution, and perform the calculation afterwards, but that is slower, so I use the step! function from the integrator interface.
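For context, the stepping pattern described here might look like the following minimal sketch. It only shows the forward solve with `init` and `step!`, without any autodiff. The function name `stepwise` and the running-minimum bookkeeping are hypothetical placeholders for the real per-step calculation.

```julia
using OrdinaryDiffEq

# Same right-hand side as in the thread, written in in-place form
function fun!(du, u, p, t)
    du .= -p[1] * t
    return nothing
end

# Hypothetical helper: advance one step at a time and do some work after each step
function stepwise(p)
    tend = 2.0
    prob = ODEProblem(fun!, [2.0], (1.0, tend), p)
    integrator = init(prob, RK4(); save_everystep = false)
    running_min = integrator.u[1]
    while integrator.t < tend
        step!(integrator)                                # one solver step
        running_min = min(running_min, integrator.u[1])  # placeholder per-step calculation
    end
    return integrator.u[1], running_min
end

u_end, u_min = stepwise([1.0])
```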

@ChrisRackauckas Hi Chris! Thanks for the update. Is the integrator interface supported with other packages such as Zygote, or ReverseDiff?

You might be able to use a discrete callback for that:

Zygote no, ReverseDiff only in slow scalar mode. So Enzyme is really needed for performance here. That limitation is fundamental to those AD libraries.

That can be done with a discrete callback.
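
A minimal sketch of that idea, assuming the per-step calculation can be expressed as a side effect inside the callback: a `DiscreteCallback` whose condition always returns `true` fires after every accepted step. The `step_times` array and the `affect!` body are hypothetical placeholders, and this only demonstrates the forward solve; whether Enzyme can differentiate through a particular callback is not something this sketch establishes.

```julia
using OrdinaryDiffEq

function fun!(du, u, p, t)
    du .= -p[1] * t
    return nothing
end

# Hypothetical per-step bookkeeping: record the time of every accepted step
step_times = Float64[]

condition(u, t, integrator) = true      # fire after every accepted step

function affect!(integrator)
    push!(step_times, integrator.t)     # the per-step calculation would go here
    return nothing
end

# save_positions = (false, false) keeps the callback from saving extra points
cb = DiscreteCallback(condition, affect!; save_positions = (false, false))

prob = ODEProblem(fun!, [2.0], (1.0, 2.0), [1.0])
sol = solve(prob, RK4(); callback = cb, save_everystep = false)
u_end = sol.u[end][1]
```

Compared with an explicit `step!` loop, this keeps everything inside a single `solve` call, which is the form that already worked with Enzyme on 1.10 earlier in the thread.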