Autodiffing with Enzyme through a DifferentialEquations ODEProblem

I’m trying to autodiff an ODE problem that I’m solving with DifferentialEquations.jl. Here’s my code:

begin
	# Import needed packages
	using DifferentialEquations
	using Enzyme
	#using Zygote
	#using SciMLSensitivity
end

function exp_ode(t)
	f(u,p,t) = 0.98u
	u0 = 1.0
	tspan = (0.0,1.0)
	prob = DifferentialEquations.ODEProblem(f,u0,tspan)
	sol = DifferentialEquations.solve(prob,abstol=1e-8,reltol=1e-8,saveat=t)
	return sol.u
end

times = Vector(LinRange(0, 1, 10))
du_dt = Enzyme.jacobian(set_runtime_activity(Reverse), t -> exp_ode(t), times)

The error message:

Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), EnzymeCore.Duplicated{@NamedTuple{abstol::Float64, reltol::Float64, saveat::Vector{Float64}}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, true, (false, true, false, false, false), true}, EnzymeCore.Const{typeof(DiffEqBase.solve_up)}, Type{EnzymeCore.Duplicated{Any}}, EnzymeCore.Duplicated{SciMLBase.ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, Main.var"workspace#102".var"#f#1", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}}, EnzymeCore.Const{Nothing}, EnzymeCore.Active{Float64}, EnzymeCore.Const{SciMLBase.NullParameters}}

This is on Julia 1.11.3 and Enzyme v0.13.30

I think the problem might be that you’re trying to differentiate with respect to a keyword argument? I’m not sure but I don’t think you can do that in Enzyme.

That would certainly make sense. But in that case, how is it possible to differentiate with respect to time with this sort of problem? I don’t see another way to use ODEProblem that makes t a non-keyword argument.

You could interpolate the solution and then take its derivative (with DataInterpolations.jl, for instance). Note, though, that f(u,p,t) evaluated at the solution is the derivative by definition.
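
A minimal sketch of that last point, assuming the same scalar problem as in the original post: because the ODE is defined by u' = f(u,p,t), evaluating f at the saved states gives du/dt at the save points without any AD. The names `times` and `du_dt` mirror the post; the comprehension is only one way to write it.

```julia
using DifferentialEquations

# Same right-hand side as in the original post
f(u, p, t) = 0.98u

times = collect(range(0.0, 1.0; length = 10))
prob = ODEProblem(f, 1.0, (0.0, 1.0))
sol = solve(prob; abstol = 1e-8, reltol = 1e-8, saveat = times)

# du/dt at each save point is the right-hand side evaluated at the solution
du_dt = [f(u, prob.p, t) for (u, t) in zip(sol.u, sol.t)]
```

This only recovers the time derivative of the state. It does not help with derivatives with respect to parameters.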


That’s not a bad idea. Good point on the f(…) being the derivative, but the goal was to show “you can do this”, and I suppose I learned you can’t (directly). I will try a more interesting function and taking the derivative wrt parameters instead of time.

You will need SciMLSensitivity.jl

Yeah you cannot differentiate w.r.t. the save points with Enzyme because it doesn’t handle keyword arguments like that.

For the record, you can do this with ForwardDiff. You would need to do a trick of forcing duals on your state and time though, i.e.:

	u0 = eltype(t)(1.0)
	tspan = eltype(t).((0.0,1.0))

and then ForwardDiff.jl would give you the derivatives w.r.t. the save points t. This is used in optimal experimental design. I know

uses this
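
To illustrate what the eltype trick does, here is a small sketch that builds dual numbers by hand instead of going through ForwardDiff.jacobian. The vector `ts` and its seed value are assumptions made up for the illustration. The point is that `eltype(ts)` is a Dual type, so `u0` and `tspan` are promoted to duals with zero partials and the solver's state can carry derivative information.

```julia
using ForwardDiff

# Hypothetical save points, each seeded with a partial of 1.0
ts = ForwardDiff.Dual.([0.0, 0.5, 1.0], 1.0)

# Promote the initial condition and time span to the same dual type
u0 = eltype(ts)(1.0)
tspan = eltype(ts).((0.0, 1.0))

# u0 carries the value 1.0 and a zero partial, since it is a constant
u0_value = ForwardDiff.value(u0)
u0_partial = ForwardDiff.partials(u0, 1)
```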


This is great, thank you Chris. Some quirks about the output I’m getting:

function exp_ode(times)
	f(u,p,t) = 0.98*u
	u0 = eltype(times)(1.0)
	tspan = eltype(times).((0.0, 1.0))
	prob = DifferentialEquations.ODEProblem(f,u0,tspan)
	sol = DifferentialEquations.solve(prob,abstol=1e-8,reltol=1e-8,saveat=times)
	return sol.u
end

begin
	using ForwardDiff
	using LinearAlgebra  # provides diag
	times = Vector(LinRange(0, 1, 10))
	p = 0.98
	du_dt = ForwardDiff.jacobian(exp_ode, times)
	du_dt_diag = diag(du_dt)
end

du_dt is a 10x10 matrix, the Jacobian. I know I can extract the derivative from the diagonal, but it’s frustrating that I can’t just use ForwardDiff.derivative(), since technically this is a vector-valued function (in reality, a scalar function evaluated at many points). Is there a more elegant solution than using the Jacobian?

The quirk is that du_dt(0) and du_dt(1) evaluated this way both come out to 0 (upper left and bottom right):

10×10 Matrix{Float64}:
 0.0  0.0      0.0      0.0      0.0      0.0      0.0     0.0      0.0      0.0
 0.0  1.09274  0.0      0.0      0.0      0.0      0.0     0.0      0.0      0.0
 0.0  0.0      1.21844  0.0      0.0      0.0      0.0     0.0      0.0      0.0
 0.0  0.0      0.0      1.35861  0.0      0.0      0.0     0.0      0.0      0.0
 0.0  0.0      0.0      0.0      1.51491  0.0      0.0     0.0      0.0      0.0
 0.0  0.0      0.0      0.0      0.0      1.68918  0.0     0.0      0.0      0.0
 0.0  0.0      0.0      0.0      0.0      0.0      1.8835  0.0      0.0      0.0
 0.0  0.0      0.0      0.0      0.0      0.0      0.0     2.10017  0.0      0.0
 0.0  0.0      0.0      0.0      0.0      0.0      0.0     0.0      2.34177  0.0
 0.0  0.0      0.0      0.0      0.0      0.0      0.0     0.0      0.0      0.0

Is this a result of our workaround with the keyword argument?
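
For comparison, this particular ODE has a closed-form solution, so the expected diagonal can be worked out by hand. This is only a check on the numbers, not a diagnosis of the cause:

$$
u' = 0.98\,u,\quad u(0) = 1 \;\Rightarrow\; u(t) = e^{0.98 t},\qquad \frac{du}{dt} = 0.98\,e^{0.98 t}
$$

At t = 1/9 this gives 0.98 e^{0.98/9} ≈ 1.09274 and at t = 8/9 it gives ≈ 2.34177, matching the interior diagonal entries. At t = 0 and t = 1 it gives 0.98 and ≈ 2.61117, so the zeros in the two corners are not the true derivative at the endpoints.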

The easiest thing would be to use the analytical solution to the sparse coloring problem: because the Jacobian is diagonal, a single directional derivative along the [1,1,1,...,1] vector recovers the whole diagonal. What this would look like is:

ForwardDiff.partials.(exp_ode(ForwardDiff.Dual.(times, 1.0)), 1)

which should give you that vector. Why that would work… is a much longer story :sweat_smile:
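
A short sketch of the seeding idea on a stand-in function, so it runs without an ODE solver. `closed_form` is a hypothetical elementwise function (the analytical solution of this ODE) and not part of the thread's code. It shows that when each output depends only on its own input, one pass with every input seeded by 1 reproduces the diagonal of the full Jacobian.

```julia
using ForwardDiff
using LinearAlgebra

# Hypothetical stand-in for the ODE solution: u(t) = exp(0.98t), elementwise
closed_form(ts) = exp.(0.98 .* ts)

times = collect(range(0.0, 1.0; length = 10))

# Full 10x10 Jacobian, which is diagonal for an elementwise function
J = ForwardDiff.jacobian(closed_form, times)

# Single pass: seed every input with partial 1.0, then read the partial of each output
seeded = closed_form(ForwardDiff.Dual.(times, 1.0))
directional = ForwardDiff.partials.(seeded, 1)

# The directional derivative along the ones vector equals the diagonal
matches = directional ≈ diag(J)
```

The seeded pass does the work of one Jacobian column instead of ten, which is the practical reason to prefer it when the Jacobian is known to be diagonal.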
