Hey All,
I’ve been meaning to calculate the gradient of an ODEProblem using Zygote, OrdinaryDiffEq, and DiffEqSensitivity. I ran into a wide range of errors along the way, but eventually found a way to do it.
Initially I followed the docs here: https://docs.sciml.ai/latest/analysis/sensitivity/
This is my code:
```julia
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

# One-compartment model with first-order absorption
function ode(u, p, t)
    CL, V, F, kₐ, dose = p
    k = CL / V
    return (F * kₐ * dose * exp(-kₐ * t)) - (k * u)
end

p = [1., 10., .5, 1., 100.]
u0 = 0.
t = [0.5, 1., 2., 3., 6., 9., 12., 24., 36., 48., 72.]
y = [0., 1.9, 3.3, 6.6, 9.1, 10.8, 8.6, 5.6, 4, 2.7, 0.8]

prob = ODEProblem(ode, u0, (0.0, 150.0), p)
sol = solve(prob, Tsit5())

function predict(p)
    _prob = remake(prob, p=p)
    # save at the observation times so that ŷ has the same length as y
    ŷ = solve(_prob, Tsit5(), saveat=t, sensealg=QuadratureAdjoint()).u
    return sum(abs2, ŷ .- y) # loss function
end

grad = gradient(predict, p)
```
This returns the error: ERROR: LoadError: MethodError: no method matching similar(::Float64, ::Int64)
Neither the error nor the stack trace is very helpful in figuring out what causes it.
In the end I found that it was because u0 was a Float64 instead of an Array{Float64}.
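The MethodError itself can be reproduced in plain Julia: `similar` is defined for arrays but not for plain numbers. My guess (an assumption, I haven't traced the DiffEqSensitivity source) is that the adjoint code allocates work buffers with something like `similar(u0, n)`, which only works when u0 is an array:

```julia
# Base Julia has no `similar` method for a plain number, so allocating a
# buffer "like u0" fails when u0 is a Float64 but works when it is a Vector.
scalar_u0 = 0.0
vector_u0 = [0.0]

buf = similar(vector_u0, 3)   # uninitialized Vector{Float64} of length 3

failed = try
    similar(scalar_u0, 3)
    false
catch err
    err isa MethodError        # no method matching similar(::Float64, ::Int64)
end
```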
Changing the line to `u0 = [0.]` and the ode function to

```julia
function ode(u, p, t)
    CL, V, F, kₐ, dose = p
    k = CL / V
    return (F * kₐ * dose * exp(-kₐ * t)) .- (k * u) # - to .-
end
```

the gradients actually get calculated.
Why does u0 have to be an Array? This is not really documented in the docs I linked, so it might be nice to either mention it there or add a more understandable error message to ODEProblem when gradient is called on it.
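To make the suggestion concrete, here is a rough sketch of the kind of early check I mean. `check_u0` is a hypothetical helper I made up for illustration; it is not part of DiffEqSensitivity or OrdinaryDiffEq:

```julia
# Hypothetical helper (not part of DiffEqSensitivity): fail early with a clear
# message instead of the MethodError coming out of `similar`.
function check_u0(u0)
    if !(u0 isa AbstractArray)
        throw(ArgumentError("u0 must be an array (e.g. [0.0] instead of 0.0) " *
                            "when differentiating through solve; got $(typeof(u0))"))
    end
    return u0
end

ok = check_u0([0.0])   # an array passes through unchanged

msg = try
    check_u0(0.0)
    ""
catch err
    sprint(showerror, err)   # the readable error text a user would see
end
```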
Also, I do not like having to change my ode function (and, in the process, make it support a multidimensional u0), so I figured I could add a line like `u = first(u)`, but that also breaks the script with the following error:
ERROR: LoadError: MethodError: Cannot convert an object of type Float64 to an object of type Array{Float64,1}
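For what it's worth, this is the generic message Julia gives whenever a Float64 has to be stored where a Vector{Float64} is expected, for example in a vector-typed container. A plain-Julia reproduction, assuming the solver keeps its states in something similar:

```julia
# A container typed for vector states cannot hold a bare Float64.
states = Vector{Float64}[]   # stands in for storage of states of type Vector{Float64}
push!(states, [0.0])         # fine: the element is a Vector{Float64}

failed = try
    push!(states, 50.0)      # a scalar where a vector was expected
    false
catch err
    err isa MethodError      # "Cannot `convert` an object of type Float64 ..."
end
```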
Does this mean only u0 is an array and other instances of u are just normal Floats? Again the question, why does u0 need to be an Array if that is the case?
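One workaround I'd try (untested with the adjoint, so treat it as an assumption) is to keep the scalar math but take the state out with `first(u)` and wrap the derivative back into a one-element vector, so that the return type matches `u0 = [0.]`:

```julia
function ode(u, p, t)
    CL, V, F, kₐ, dose = p
    k = CL / V
    uₛ = first(u)                                    # scalar state from the 1-element vector
    return [F * kₐ * dose * exp(-kₐ * t) - k * uₛ]   # wrap the derivative back into a vector
end

p = [1., 10., .5, 1., 100.]
du = ode([0.0], p, 0.0)   # F * kₐ * dose = 0.5 * 1 * 100, so du == [50.0]
```

The body stays scalar, but what goes in and what comes out are both vectors, which is what the error messages above seem to require.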
Maybe someone is willing to discuss this?