BoundsError in Flux.train! with DiffEqFlux

Hi,

I’m trying to use DiffEqFlux to solve a scalar ODE of the form du/dt = f(X(t)) - a * u, where X(t) is a set of functions of time. I modeled f with a neural network, and I’m trying to optimize the parameters of this neural network together with the parameter a to match some data.
However, I get a BoundsError during training and I can’t figure out why.

Here is an MWE:

using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq

u_data = randn(101)                              # data to fit
u0 = u_data[1]                                   # scalar initial condition
u_est = (t, i) -> t ^ i                          # the i-th component of X(t)
m = Chain(Dense(10, 11, tanh), Dense(11, 1))     # neural network modeling f
pm, re_m = Flux.destructure(m)                   # flat parameter vector + rebuilder
n_param = length(pm)
p_ode = [pm; 1.0]                                # NN parameters followed by a
du(u, p, t) = re_m(p[1:n_param])([u_est(t, i) for i = 1:10])[1] - p[end] * u
prob_ode = ODEProblem(du, u0, (0.0, 10.0), p_ode)
predict_adjoint() = concrete_solve(prob_ode, Tsit5(), u0, saveat = 0:0.1:10, abstol = 1e-6,
                 reltol = 1e-6, sensealg = InterpolatingAdjoint(checkpointing = true))
loss_adjoint() = sum(abs2, predict_adjoint().u .- u_data)
Flux.train!(loss_adjoint,
            Flux.params(p_ode),
            Iterators.repeated((), 10),
            ADAM(0.05),
            cb = () -> println("loss: ", round(loss_adjoint(), digits = 2)))

I also noticed that I get the same error when taking the gradient directly with Zygote:

using Zygote
Zygote.gradient(() -> loss_adjoint(), Flux.params(p_ode))
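
For reference, an equivalent explicit-parameter form is sketched below (loss_adjoint_p is just an illustrative name, and this assumes concrete_solve accepts u0 and p positionally as concrete_solve(prob, alg, u0, p; kwargs...)):

# hypothetical explicit-parameter variant of the loss above
loss_adjoint_p(p) = sum(abs2, concrete_solve(prob_ode, Tsit5(), u0, p, saveat = 0:0.1:10,
                        abstol = 1e-6, reltol = 1e-6,
                        sensealg = InterpolatingAdjoint(checkpointing = true)).u .- u_data)
Zygote.gradient(loss_adjoint_p, p_ode)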

Many thanks in advance for helping me understand what I am doing wrong.

Scalar values aren’t supported for parameters in Flux, so u0 needs to be a vector. The following trains:

using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq
u_data = randn(101)
u0 = [u_data[1]]                                 # initial condition as a 1-element vector, not a scalar
u_est = (t, i) -> t ^ i
m = Chain(Dense(10, 11, tanh), Dense(11, 1))
pm, re_m = Flux.destructure(m)
n_param = length(pm)
p_ode = [pm; 1.0]
du(u, p, t) = re_m(p[1:n_param])([u_est(t, i) for i = 1:10])[1] .- p[end] * u   # broadcast so du returns a vector
prob_ode = ODEProblem(du, u0, (0.0, 10.0), p_ode)
predict_adjoint() = concrete_solve(prob_ode, Tsit5(), u0, saveat = 0:0.1:10, abstol = 1e-6,
                 reltol = 1e-6, sensealg = InterpolatingAdjoint(checkpointing = true))
loss_adjoint() = sum(abs2, predict_adjoint() .- u_data)
Flux.train!(loss_adjoint,
            Flux.params(p_ode),
            Iterators.repeated((), 10),
            ADAM(0.05),
            cb = () -> println("loss: ", round(loss_adjoint(), digits = 2)))
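
After training, the fitted quantities can be read back out of p_ode, since Flux.train! updates it in place through the implicit parameters. A small usage sketch (f_trained and a_trained are just illustrative names):

f_trained = re_m(p_ode[1:n_param])   # rebuild the trained neural network for f
a_trained = p_ode[end]               # the trained decay parameter a
println("estimated a: ", a_trained)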

Thank you for the quick answer. Indeed, this works fine; I didn’t know about that restriction on scalars in Flux!

Edit: I just noticed that, in the loss function, predict_adjoint() should be replaced by reduce(vcat, predict_adjoint()) to compute the right loss.
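
For completeness, a sketch of the corrected loss function described in that edit:

loss_adjoint() = sum(abs2, reduce(vcat, predict_adjoint()) .- u_data)   # flatten the 1-element state vectors before comparing to u_data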