I’m new to using autodiff in Julia and am having trouble making Enzyme.jl work with OrdinaryDiffEq.jl. I want to take a reverse-mode derivative of my differential equation solution with respect to the input parameter p. This is a simple example, but my actual model also needs the integrator interface of OrdinaryDiffEq, so I am looking for a way to make Enzyme work in reverse mode with the integrator interface. Here is my implementation:
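```julia
using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity

# Out-of-place scalar right-hand side: du/dt = -p[1] * t
function fun(u, p, t)
    return -p[1] * t
end

function test_fun(p; kwargs...)
    u0 = 2.0
    tspan = (1.0, 2.0)
    prob = ODEProblem{false}(fun, u0, tspan, p)
    # Integrator interface; with save_everystep = false, .u holds only the start and end states
    integrator = init(prob, RK4(); save_everystep = false, kwargs...)
    sol = solve!(integrator).u[2]
    return sol
end

p = [1.0]
dp = [0.0]   # shadow of p; Enzyme accumulates the gradient here
# test_fun takes only p and returns a scalar, so the return is Active and p is Duplicated
Enzyme.autodiff(Reverse, test_fun, Active, Duplicated(p, dp))
```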
I have tried several variations of this, but nothing seems to work. Any input would be helpful!
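For checking any of the gradients in this thread: the right-hand side depends only on t, so the toy problem can be integrated by hand. This gives the exact value the computed `dp` should match (-1.5 at p = [1.0]). RK4 reduces to Simpson's rule for a right-hand side that depends only on t, so it reproduces this linear-in-t case exactly up to rounding.

$$
\frac{du}{dt} = -p_1 t,\quad u(1) = 2
\;\Longrightarrow\;
u(2) = 2 - p_1 \int_1^2 t\,dt = 2 - \tfrac{3}{2}\,p_1,
\qquad
\frac{\partial u(2)}{\partial p_1} = -\tfrac{3}{2}.
$$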
You can definitely do this with Enzyme, but I suggest using DifferentiationInterface.jl with the Enzyme backend. I find that API more intuitive, and you also get access to other backends for free.
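```julia
using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity

# In-place right-hand side: du must be mutable, so u0 is a vector
function fun(du, u, p, t)
    du .= -p[1] * t
    nothing
end

function test_fun(p, prob)
    # Differentiate through remake rather than through the ODEProblem construction
    prob = remake(prob, p = p)
    sol = solve(prob, RK4(), save_everystep = false)
    res = sol.u[2]   # final state (only the start and end are saved)
    return res[1]
end

p = [1.0]
dp = [0.0]
u0 = [2.0]
prob = ODEProblem{true}(fun, u0, (1.0, 2.0), p)
dprob = Enzyme.make_zero(prob)   # zero shadow for the problem
Enzyme.autodiff(Enzyme.Reverse, test_fun, Duplicated(p, dp), DuplicatedNoNeed(prob, dprob))
@info dp
```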
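A minimal sketch of the DifferentiationInterface.jl route suggested above, applied to the same in-place problem. It assumes a recent DifferentiationInterface.jl whose `AutoEnzyme` accepts the `mode` and `function_annotation` keywords, and that ForwardDiff.jl is installed for the comparison backend. `make_loss` is a hypothetical helper name. Marking the closure `Enzyme.Duplicated` is intended to give the captured `prob` a zero shadow, playing the role of `DuplicatedNoNeed(prob, dprob)` in the direct Enzyme call.

```julia
using OrdinaryDiffEq, SciMLSensitivity
using DifferentiationInterface   # exports gradient and backend types such as AutoEnzyme, AutoForwardDiff
import Enzyme, ForwardDiff

function fun(du, u, p, t)
    du .= -p[1] * t
    return nothing
end

function test_fun(p, prob)
    newprob = remake(prob, p = p)                        # differentiate through remake
    sol = solve(newprob, RK4(); save_everystep = false)
    return sol.u[end][1]                                 # scalar: final value of u[1]
end

# Hypothetical helper: returns a closure that stores prob as a captured field
# (a bare top-level lambda would read the global prob instead of capturing it)
make_loss(prob) = p -> test_fun(p, prob)

prob = ODEProblem{true}(fun, [2.0], (1.0, 2.0), [1.0])
loss = make_loss(prob)

# Reverse-mode Enzyme; the Duplicated annotation gives the captured prob a zero shadow
enzyme_backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)
g_enzyme = gradient(loss, enzyme_backend, [1.0])

# Same call, different backend: only the second argument changes
g_fwd = gradient(loss, AutoForwardDiff(), [1.0])
```

Both gradients should be close to -1.5 for this problem, since u(2) = 2 - 1.5 p[1].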
The one thing I think you can’t get around is making u0 a vector, because SciMLSensitivity needs to go through your du (which must be mutable, like a Ref). Also, always differentiate a remake rather than the original ODEProblem construction; differentiating the construction may work, but it’s really not clean. You could go back to an explicit (out-of-place) definition of `fun`, but since I made u0 a vector, it’s better to avoid that.
@yolhan_mannes Following up on this: do you have any idea how to make it work with the integrator interface? It works with the direct `solve`, but not when I use the integrator interface.
With Enzyme, that is not fully supported yet. I hope to change that in the next few months, but it will take a while, and there are certain trade-offs involved.
@yolhan_mannes I’m not exactly sure what you meant, but I need to solve the differential equation one step at a time, perform some calculation, and then move on to the next step. I could set `save_everystep = true`, store the solution, and perform the calculation afterwards, but that is slower, so I use the `step!` function from the integrator interface.
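For the step-by-step workflow described here, a minimal sketch of the integrator loop itself (no AD involved), using `init` and `step!` on the scalar problem from the question. The per-step work (tracking the running minimum of u and counting steps) is a hypothetical placeholder for the real model's calculation, and `step_with_work` is an illustrative name. The loop sits inside a function so the counters are local variables rather than globals modified in a top-level loop.

```julia
using OrdinaryDiffEq

# Out-of-place scalar right-hand side, as in the original question
fun(u, p, t) = -p[1] * t

# Advance one step at a time, doing some (placeholder) work between steps
function step_with_work(prob)
    integrator = init(prob, RK4(); save_everystep = false)
    tend = last(prob.tspan)        # the final time is a tstop, so the loop lands on it exactly
    nsteps = 0
    umin = integrator.u
    while integrator.t < tend
        step!(integrator)                 # take exactly one (adaptive) step
        nsteps += 1
        umin = min(umin, integrator.u)    # per-step calculation using the current state
    end
    return integrator.u, umin, nsteps
end

prob = ODEProblem{false}(fun, 2.0, (1.0, 2.0), [1.0])
uend, umin, nsteps = step_with_work(prob)
```

Since u decreases monotonically on this problem, the running minimum equals the final state, and the final state should be 0.5 for p = [1.0].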
@ChrisRackauckas Hi Chris! Thanks for the update. Is the integrator interface supported by other packages such as Zygote or ReverseDiff?