Hey,
Sorry it took a while to get to this. The issue was that you were mixing the old, deprecated syntax with the new one. What you wanted was:
# Build the problem locally and pass `u0` and `p` to `solve` as keyword arguments;
# this keyword form is what the sensitivity (adjoint) machinery differentiates.
sol = let prob = ODEProblem(VO_system!, u0, tspan, p)
    solve(prob, Tsit5();
        u0 = u0,
        p = p,
        saveat = range(tspan[1], tspan[2], length = 500),
        sensealg = QuadratureAdjoint(),
        abstol = 1e-10,
        reltol = 1e-10)
end
i.e. use `solve` and pass `u0` and `p` as keyword arguments. A full example that double-checks the gradient values against ForwardDiff is:
using Flux, OrdinaryDiffEq, DiffEqSensitivity
# Encapsulate all device parameters in the VO struct.
# The field types are type parameters rather than `::Any`, so every instance has
# concretely typed fields (e.g. `VO{Vector{Float64},Vector{Float64}}`) and code
# reading them can be specialized by the compiler. The positional constructor
# `VO(Cth, Rth)` keeps the same signature, so existing callers are unaffected.
struct VO{C,R}
    # `VO()` stores this as a one-element vector and the call method reads `Cth[1]`.
    # NOTE(review): presumably a thermal capacitance — confirm units.
    Cth::C
    # `VO()` stores this as a one-element vector and the call method reads `Rth[1]`.
    # NOTE(review): presumably a thermal resistance — confirm units.
    Rth::R
end
Flux.@functor VO
# Expose both parameter arrays to Flux as trainable. The order is (Rth, Cth), so
# `Flux.params(::VO)` lists Rth first and Cth second.
function Flux.trainable(vo::VO)
    return (vo.Rth, vo.Cth)
end
# Zero-argument constructor: build a VO with the default initial values
# Cth = 1.63e-13 and Rth = 4.0e5, each wrapped in a one-element vector
# (presumably so Flux can track them as mutable parameter arrays).
function VO()
    initial_Cth = 1.63e-13
    initial_Rth = 4.0e5
    return VO([initial_Cth], [initial_Rth])
end
# Call method of VO: integrate the 20-state device ODE over t ∈ [0, 1e-5] and
# return the ODE solution saved at 500 equally spaced times.
#   Input_time_series — currently not used by the computation.
#   p                 — parameter vector [Cth, Rth]; defaults to the values held
#                       in `device`, so gradients flow back to its fields.
# The solve uses the QuadratureAdjoint sensitivity algorithm and tolerances of 1e-10.
function (device::VO)(Input_time_series, p = [device.Cth[1], device.Rth[1]])
    tspan = (0.0, 1e-5)
    # States 1:10 start at 1.0 and states 11:20 start at 300.0.
    u0 = vcat(fill(1.0, 10), fill(300.0, 10))
    ts = range(tspan[1], tspan[2], length = 500)

    # In-place right-hand side for the single-device parameters (C, R) = q:
    #   dV/dt = -V
    #   dT/dt = 1/C * (V - T/R)
    function rhs!(du, u, q, t)
        C, R = q
        V = u[1:10]
        T = u[11:end]
        du[1:10] = -V
        du[11:end] .= 1 ./ C .* (V .- T ./ R)
        return nothing
    end

    prob = ODEProblem(rhs!, u0, tspan, p)
    # `u0` and `p` are passed again as keywords so the adjoint sees them as inputs.
    return solve(prob, Tsit5();
        u0 = u0,
        p = p,
        saveat = ts,
        sensealg = QuadratureAdjoint(),
        abstol = 1e-10,
        reltol = 1e-10)
end
###############################################Training#########################
device = VO()
model = device
ps = Flux.params(model)
loss(x, y) = Flux.mse(model(x), y)
println("Tracked Parameters: ", ps)

# Reverse-mode gradient through the adjoint solve. `ps` is ordered (Rth, Cth).
gs = Flux.gradient(ps) do
    loss(0, 0)
end
@show gs[ps[1]]
@show gs[ps[2]]

# Forward-mode cross-check. Here the parameter vector is ordered [Cth, Rth],
# so grad[2] is printed first to line up with gs[ps[1]] (Rth).
using ForwardDiff
grad = ForwardDiff.gradient([1.63e-13, 4.0e5]) do p
    Flux.mse(model(0, p), 0)
end
@show grad[2]
@show grad[1]
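which outputs the following. Note that `ps` lists `Rth` before `Cth` (the order returned by `Flux.trainable`), while the ForwardDiff parameter vector is `[Cth, Rth]`. So `gs[ps[1]]` should be compared with `grad[2]`, and `gs[ps[2]]` with `grad[1]`: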
Tracked Parameters: Params([[400000.0], [1.63e-13]])
gs[ps[1]] = [3.937418595344859e9]
gs[ps[2]] = [-4.787204526955042e25]
grad[2] = 3.9374185953439655e9
grad[1] = -4.787204527049037e25