I’m trying to take the gradient of a loss function (which includes solving an `ODEProblem` with DifferentialEquations.jl) with respect to the parameters of a neural network. I was able to get ForwardDiff working for this, but forward-mode differentiation scales poorly here. Its cost grows with the number of inputs, and this problem maps many network parameters to a single scalar loss. That shape is better suited to reverse-mode autodiff, but both Zygote and Enzyme are giving me problems.
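To make the cost argument concrete, here is a minimal sketch on a toy stand-in loss (`toy_loss` is hypothetical and not the ODE loss below). ForwardDiff and Zygote return the same gradient, but ForwardDiff gets there by pushing small chunks of dual-number directions through the function, so the number of forward passes grows with `length(p)`. Zygote needs one forward and one backward pass for a scalar output, regardless of `length(p)`.

```julia
using ForwardDiff, Zygote

# Hypothetical toy scalar loss of N parameters, standing in for the real loss.
toy_loss(p) = sum(abs2, p .- 1) / length(p)

p = randn(Float32, 1_000)

# Forward mode: each pass carries a chunk of directional derivatives,
# so the number of passes grows with length(p).
g_fwd = ForwardDiff.gradient(toy_loss, p)

# Reverse mode: one backward pass yields the full gradient of a scalar output.
g_rev = Zygote.gradient(toy_loss, p)[1]

println(maximum(abs, g_fwd .- g_rev))  # agreement up to Float32 roundoff
```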
```julia
begin
    # Import needed packages
    using Flux
    using Zygote
    using ForwardDiff
    using DifferentialEquations
    using Plots
    using Enzyme
    using SciMLSensitivity
    using LinearAlgebra
end

function u_true(t)
    u_0 = 1
    u = u_0*exp.(t.^3/3)
    return u
end

function f2(u, p, t) # the actual ODE
    m = re(p)  # rebuild the Chain from the flat parameter vector p
    # t is a scalar here, so t' == t and [t'] is a 1-element vector, the array
    # input a Dense layer expects; [1] extracts the scalar output.
    # Broadcasting eltype(p) converts the result to the element type of p.
    return eltype(p).(m([t'])[1]*u)
end

function eval_model(t, p)
    u0 = eltype(t)(1.0)
    tspan = eltype(t).((0.0, 1.0))
    prob = DifferentialEquations.ODEProblem(f2, u0, tspan, p)
    sol = DifferentialEquations.solve(prob, abstol=1e-8, reltol=1e-8, saveat=t)
    return Array(sol.u)
end

function loss(t, p, y_true)
    # recall that the model needs to be used as a component of an equation, not compared directly to our training data
    y_nn = eval_model(t, p)
    return Flux.Losses.mse(y_nn, y_true)
end

n_in = 1
n_out = 1
model = Chain(
    Dense(n_in, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 10, relu),
    Dense(10, n_out));

# Test the eval_model and loss functions
t = Vector{Float32}(LinRange(0, 1, 10))
p, re = Flux.destructure(model)
y_true = eltype(t).(u_true(t))
loss_fd(x) = loss(t, x, y_true)
# grads = ForwardDiff.gradient(loss_fd, p)
grads = Zygote.jacobian(p -> loss_fd(p), p)
```
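The comment in `f2` is about why `re(p)`, `[t']` and `eltype` are there. The following minimal sketch uses a smaller stand-in network (an assumption; only the 1-input/1-output shape matches the question) to show the first two pieces. `Flux.destructure` flattens all weights and biases into one vector `p` and returns `re`, which rebuilds an identical model from such a vector. `[t']` is just a length-1 vector, because the adjoint of a real scalar is the scalar itself, and that array is the input a `Dense` layer expects.

```julia
using Flux

# Stand-in network: 1 input, 1 output, one hidden layer of width 10.
model = Chain(Dense(1 => 10, relu), Dense(10 => 1))

# p: all weights and biases flattened into one Vector{Float32}
# re: callable that rebuilds a model with this structure from such a vector
p, re = Flux.destructure(model)

t = 0.5f0      # the solver passes time to f2 as a scalar
x = [t']       # t' == t for a real scalar, so this is the 1-element vector [t]
m = re(p)      # rebuilt model with the same weights as `model`
y = m(x)       # Dense layers take array input; y is a 1-element Vector{Float32}

println(length(p))  # 31, i.e. (10 + 10) + (10 + 1) parameters
println(y[1])       # the scalar that f2 multiplies by u
```

In `f2`, `m([t'])[1]` pulls that scalar back out of the length-1 output. Note also that `u_true(t) = exp(t^3/3)` solves du/dt = t^2 u with u(0) = 1, so the network is effectively being fit to approximate t^2.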
With the above code, I get warnings that all reverse-mode VJP choices have failed and that it is falling back to numerical VJPs. Then I get the following error from SciMLSensitivity:
```
MethodError: no method matching similar(::Float32, ::Int64, ::Int64)
The function `similar` exists, but no method is defined for this combination of argument types.
Closest candidates are:
  similar(::Type{T}, ::Union{Integer, AbstractUnitRange}...) where T<:AbstractArray
   @ Base abstractarray.jl:866
  similar(::BandedMatrices.AbstractBandedMatrix, ::Integer, ::Integer)
   @ BandedMatrices ~/.julia/packages/BandedMatrices/KJZ2p/src/banded/BandedMatrix.jl:374
  similar(::BandedMatrices.AbstractBandedMatrix, ::Integer, ::Integer, ::Integer, ::Integer)
   @ BandedMatrices ~/.julia/packages/BandedMatrices/KJZ2p/src/banded/BandedMatrix.jl:375
  ...
```
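The `MethodError` shows `similar` being called on a `Float32`, i.e. on the scalar state `u0`. One plausible reading, and an assumption here rather than a confirmed diagnosis, is that the adjoint code paths expect an array-valued state. The sketch below is a hypothetical variant (`f2_vec`, `eval_model_vec` and `loss_vec` are illustrative names) that stores the state as a 1-element vector, drops the `eltype` conversions, picks `Tsit5()` explicitly, and requests `InterpolatingAdjoint(autojacvec = ZygoteVJP())` instead of relying on the automatic choice. The tolerances are loosened to 1e-6 because `eps(Float32)` is about 1.2e-7, so a 1e-8 relative tolerance is below what `Float32` can resolve.

```julia
using Flux, Zygote, DifferentialEquations, SciMLSensitivity

# Same architecture as in the question.
model = Chain(Dense(1 => 10, relu), Dense(10 => 10, relu),
              Dense(10 => 10, relu), Dense(10 => 1))
p, re = Flux.destructure(model)

# Hypothetical out-of-place RHS: the state u is a 1-element Vector,
# so the adjoint machinery works with an array rather than a scalar.
function f2_vec(u, p, t)
    m = re(p)              # rebuild the network from the flat parameters
    return m([t]) .* u     # length-1 network output times length-1 state
end

function eval_model_vec(t, p)
    T = eltype(t)
    u0 = [one(T)]          # vector state instead of a scalar
    prob = ODEProblem(f2_vec, u0, (zero(T), one(T)), p)
    sol = solve(prob, Tsit5(); abstol = 1e-6, reltol = 1e-6, saveat = t,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return vec(Array(sol))  # 1×length(t) matrix -> Vector
end

loss_vec(t, p, y_true) = Flux.Losses.mse(eval_model_vec(t, p), y_true)

t = Vector{Float32}(LinRange(0, 1, 10))
y_true = exp.(t .^ 3 ./ 3)  # u_true from the question, inlined (Float32)

# Scalar loss, so Zygote.gradient rather than Zygote.jacobian.
grads = Zygote.gradient(q -> loss_vec(t, q, y_true), p)[1]
```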
When I try Enzyme, it complains about a non-constant keyword argument (the `saveat = t` time points in this case). There may be workarounds that interpolate the solution outside of the ODE solver, but that would introduce an additional source of error that is hard to quantify, and I’d prefer to find a cleaner solution.