Flux does not support scalar values as trainable parameters, so `u0` needed to be a vector. The following version trains:
using DiffEqFlux, DiffEqSensitivity
# Target series: 101 draws from a standard normal distribution.
u_data = randn(101)
# Initial state as a one-element vector holding the first target value.
u0 = u_data[1:1]
# Monomial feature: raises the time `t` to the power `i`.
u_est = function (t, i)
    return t^i
end
# Two-layer network: 10 inputs -> 11 tanh units -> 1 output.
m = Chain(
    Dense(10, 11, tanh),
    Dense(11, 1),
)
# Flatten the network weights into `pm`; `re_m` rebuilds a network from such a vector.
pm, re_m = Flux.destructure(m)
n_param = length(pm)
# Full ODE parameter vector: the network weights followed by one extra
# coefficient (initialised to 1.0) that `du` reads as `p[end]`.
p_ode = vcat(pm, 1.0)
# Right-hand side of the ODE.
#
# The first `n_param` entries of `p` are turned back into the network by `re_m`.
# The network is evaluated on the features `u_est(t, 1), ..., u_est(t, 10)` and
# its first output is used as the forcing term. The last entry of `p` scales a
# term in `u` that is subtracted from that forcing.
function du(u, p, t)
    net = re_m(p[1:n_param])
    features = [u_est(t, k) for k in 1:10]
    forcing = net(features)[1]
    return forcing .- p[end] * u
end
# ODE problem that ties the right-hand side to its initial state, time span
# and parameter vector.
prob_ode = ODEProblem(
    du,           # right-hand side du(u, p, t)
    u0,           # initial state (one-element vector)
    (0.0, 10.0),  # integration interval
    p_ode,        # network weights plus the trailing coefficient
)
# Solve the ODE from `u0` with the current parameters and return the solution
# saved at t = 0, 0.1, ..., 10.
#
# `p_ode` is passed explicitly instead of being left to default to `prob_ode.p`,
# so the solve does not depend on the problem's stored parameter array aliasing
# the vector that is being trained.
function predict_adjoint()
    return concrete_solve(
        prob_ode, Tsit5(), u0, p_ode;
        saveat = 0:0.1:10,
        abstol = 1e-6,
        reltol = 1e-6,
        sensealg = InterpolatingAdjoint(checkpointing = true),
    )
end
# Sum of squared errors between the predicted trajectory and the target series.
#
# The solution of a one-state ODE is indexed as (state, time), i.e. it acts as a
# 1×N array, while `u_data` is a length-N vector. Broadcasting the two directly
# would produce an N×N matrix of all pairwise differences, so the prediction is
# flattened to a vector first and the residuals are taken point by point.
#
# Both arguments default to the script's own prediction and data, so
# `loss_adjoint()` still works as the zero-argument loss used for training.
loss_adjoint(prediction = predict_adjoint(), target = u_data) =
    sum(abs2, vec(Array(prediction)) .- target)
# Run 10 optimisation steps on `p_ode`, printing the rounded loss after each step.
let
    trainable = Flux.params(p_ode)
    # Ten empty data tuples: the loss takes no arguments, so each one is one step.
    steps = Iterators.repeated((), 10)
    optimiser = ADAM(0.05)
    report = () -> println("loss :", round(loss_adjoint(), digits = 2))
    Flux.train!(loss_adjoint, trainable, steps, optimiser; cb = report)
end