Autodiffing through a DE.jl ODEProblem wrt neural network parameters in reverse mode

I’m trying to take the gradient of a loss function (which includes solving an ODEProblem with DifferentialEquations.jl) with respect to the parameters of a neural network. I was able to get ForwardDiff working for this, but forward-mode differentiation scales poorly with the number of parameters, since it propagates one derivative per input while we have many parameters mapping to a single scalar loss. This problem is better suited to reverse-mode autodiff, but both Zygote and Enzyme are giving me problems.

begin
	# Import needed packages
	using Flux
	using Zygote
	using ForwardDiff
	using DifferentialEquations
	using Plots
	using Enzyme
	using SciMLSensitivity
	using LinearAlgebra
end

function u_true(t)
	# Analytic solution of du/dt = t^2 * u with u(0) = 1,
	# so the network should learn m(t) ≈ t^2
	u_0 = 1
	return u_0*exp.(t.^3/3)
end

function f2(u, p, t) # the actual ODE right-hand side: du/dt = m(t)*u
	m = re(p) # re(p) rebuilds the Chain from the flat parameter vector p
	# [t'] wraps the scalar time into a length-1 array to match the Dense input size;
	# eltype(p).( ) casts to p's element type so dual/tracked numbers propagate
	return eltype(p).(m([t'])[1]*u)
end
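
As a quick aside on the shape issue (a minimal sketch, separate from the training code): Flux Dense layers accept arrays, not bare scalars, which is why the scalar time has to be wrapped before being fed to the network.

d = Dense(1, 10)  # expects a length-1 vector (or a matrix with one row)
d([0.5f0])        # works: returns a length-10 vector
# d(0.5f0)        # MethodError: Dense has no method for scalar input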
function eval_model(t, p)
	u0 = eltype(t)(1.0)
	tspan = eltype(t).((0.0, 1.0))
	prob = DifferentialEquations.ODEProblem(f2, u0, tspan, p)
	sol = DifferentialEquations.solve(prob, abstol=1e-8, reltol=1e-8, saveat=t) # save at the training times
	return Array(sol.u)
end
function loss(t, p, y_true)
	# the network is a component of the ODE right-hand side, so we compare the
	# ODE solution to the training data, not the raw network output
	y_nn = eval_model(t, p)
	return Flux.Losses.mse(y_nn, y_true)
end

n_in = 1
n_out = 1
model = Chain(
		Dense(n_in,10,relu),
		Dense(10,10,relu),
		Dense(10,10,relu),
		Dense(10,n_out));
# Test the eval_model and loss functions
t = Vector{Float32}(LinRange(0, 1, 10))
p, re = Flux.destructure(model) # p: flat parameter vector; re: rebuilds the model from p
y_true = eltype(t).(u_true(t))
loss_fd(x) = loss(t, x, y_true)
# grads = ForwardDiff.gradient(loss_fd, p) # works, but cost scales with length(p)
grads = Zygote.jacobian(p -> loss_fd(p), p) # the failing reverse-mode call

With the above code, I get warnings that all reverse-mode VJP choices have failed and that Zygote will fall back to numerical VJPs, followed by this error from SciMLSensitivity:

MethodError: no method matching similar(::Float32, ::Int64, ::Int64)
The function `similar` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  similar(::Type{T}, ::Union{Integer, AbstractUnitRange}...) where T<:AbstractArray
   @ Base abstractarray.jl:866
  similar(::BandedMatrices.AbstractBandedMatrix, ::Integer, ::Integer)
   @ BandedMatrices ~/.julia/packages/BandedMatrices/KJZ2p/src/banded/BandedMatrix.jl:374
  similar(::BandedMatrices.AbstractBandedMatrix, ::Integer, ::Integer, ::Integer, ::Integer)
   @ BandedMatrices ~/.julia/packages/BandedMatrices/KJZ2p/src/banded/BandedMatrix.jl:375
  ...
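
For reference, SciMLSensitivity also lets you pin the adjoint and VJP choice explicitly instead of relying on the automatic selection that the warning says is failing. A minimal sketch (InterpolatingAdjoint and ZygoteVJP are SciMLSensitivity types; untested against this exact setup), which should at least surface the failing backend's actual error instead of silently falling back to numerical VJPs:

function eval_model_explicit(t, p)
	u0 = eltype(t)(1.0)
	tspan = eltype(t).((0.0, 1.0))
	prob = DifferentialEquations.ODEProblem(f2, u0, tspan, p)
	# request a Zygote-based VJP inside an interpolating adjoint explicitly,
	# rather than letting SciMLSensitivity cycle through all backends
	sol = DifferentialEquations.solve(prob, Tsit5(), abstol=1e-8, reltol=1e-8,
		saveat=t, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))
	return Array(sol.u)
end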

When I try Enzyme, it complains about a non-constant keyword argument (the time vector t, passed as saveat, in this case). There may be workarounds that interpolate outside the ODEProblem solve, but that would introduce additional sources of error, and I'd prefer a cleaner solution.
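
For concreteness, the Enzyme call in question looks roughly like this (a sketch, assuming the Enzyme.gradient(Reverse, f, x) entry point):

# sketch only: the closure captures t, which is also passed to solve as saveat,
# and that non-constant data appears to be what Enzyme complains about
grads_enz = Enzyme.gradient(Enzyme.Reverse, p -> loss(t, p, y_true), p)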

Did you try the Lux versions? We use Lux everywhere in the docs now because the latest breaking change to Flux makes it pretty unusable for lots of things. I would recommend just using the Lux versions if you want to keep things easy.
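
For anyone following along, here is a minimal sketch of the Lux version of the model setup (assuming Lux.jl and ComponentArrays.jl; names like f2_lux are just for illustration). Lux models are stateless, with parameters passed explicitly at each call, which tends to compose better with SciMLSensitivity:

using Lux, Random, ComponentArrays

lux_model = Lux.Chain(Lux.Dense(1, 10, relu), Lux.Dense(10, 10, relu),
	Lux.Dense(10, 10, relu), Lux.Dense(10, 1))
ps, st = Lux.setup(Random.default_rng(), lux_model)
ps = ComponentArray(ps) # flat, AD-friendly view of the nested parameter NamedTuple

function f2_lux(u, p, t)
	y, _ = lux_model([t], p, st) # Lux takes parameters explicitly and returns (output, state)
	return y[1] * u
end

From there eval_model and loss carry over unchanged with p = ps, and the reverse-mode call becomes Zygote.gradient(p -> loss(t, p, y_true), ps).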