Training NN with Loss from Differential Equation


Say I have a differential equation of the form du/dt = NN(u), where NN(u) is the output of a neural network that takes the current value of u as input.
The initial condition is u(0) = 0.
I want to train the neural network so that, at a specified time t = 2, the solution reaches a specified position uf = 1, i.e. u(2) = uf.
This is meant to be a toy version of the problem of teaching an agent to find paths…
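Written as an optimization problem (θ is my shorthand for the network's weights and does not appear in the code), the goal and the gradient that training needs are:

```latex
\frac{du}{dt} = \mathrm{NN}_\theta(u), \qquad u(0) = 0, \qquad
L(\theta) = \bigl(u(2;\theta) - u_f\bigr)^2, \quad u_f = 1,

\frac{\partial L}{\partial \theta} = 2\,\bigl(u(2;\theta) - u_f\bigr)\,\frac{\partial u(2;\theta)}{\partial \theta}.
```

The last factor is the sensitivity of the ODE solution to the weights, which is what has to be propagated back through the solver. As a sanity check that the target is reachable: a network that outputs a constant velocity c gives u(t) = ct, so u(2) = 1 needs c = 1/2, which lies inside the (0, 1) range of the final σ layer.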

After reading about Flux and DifferentialEquations.jl, I figured I could approach the problem with the following code:

```julia
using Flux                  # Chain, Dense, relu, σ, ADAM, train!, params
using DifferentialEquations # ODEProblem, solve, Tsit5
using SciMLSensitivity      # sensitivity (adjoint) rules for solve

# define NN model
model = Chain(
    Dense(1, 4, relu),
    Dense(4, 1, σ))         # Dense applies σ elementwise; output velocity lies in (0, 1)
ps = params(model)

# define differential equation
function f!(du, u, p, t)
    du[1] = model([u[1]])[1] # NN controls velocity (weights read from the global `model`)
    return nothing
end

# calculate final position for our NN controller
function final_position()
    u0 = [0.]          # starts at 0.
    tspan = (0.0, 2.0) # system evolves for time = 2.
    prob = ODEProblem(f!, u0, tspan) # set problem (no parameter vector p is passed)
    sol = solve(prob, Tsit5(), save_everystep = false, save_start = false) # solve diff equation
    Xf = sol.u[end][1] # final position (only the end point is saved)
    return Xf
end

# define loss function
uf = 1.
function loss()
    Xf = final_position() # final position with NN controller starting at u0 = 0
    return (Xf - uf)^2
end

# set optimizer
opt = ADAM(0.3)

# define (empty) data
x_train = Iterators.repeated((), 100)

# do one training round
train!(loss, ps, x_train, opt)
```

But I got the following warning

and then the parameters are not updated at all.

I suspect this happens because automatic differentiation has trouble differentiating a function of the output of a differential equation solver, but I could not find a way to work around it.
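A minimal, untested sketch of the pattern used in the SciMLSensitivity neural-ODE examples, under the assumption that the cause is that `f!` reads its weights from the global `model`: as far as I understand, the sensitivity rules in SciMLSensitivity only propagate gradients with respect to `u0` and the parameter argument `p`, so here the weights are flattened with `Flux.destructure` and handed to `solve` as `p`. The names `velocity` and `trained_model`, the `InterpolatingAdjoint(autojacvec = ZygoteVJP())` choice, and plain gradient descent with step `η = 0.1` (instead of `ADAM(0.3)`) are illustrative assumptions.

```julia
using Flux                   # Chain, Dense, relu, σ, f64, destructure
using DifferentialEquations  # ODEProblem, solve, Tsit5
using SciMLSensitivity       # InterpolatingAdjoint, ZygoteVJP: lets Zygote differentiate `solve`
using Zygote                 # reverse-mode AD used for the gradient

# Same architecture as above, converted to Float64 to match u0 and tspan
model = f64(Chain(Dense(1 => 4, relu), Dense(4 => 1, σ)))

# Flatten the weights into one vector p0; re(p) rebuilds a network from any such vector
p0, re = Flux.destructure(model)

# Out-of-place right-hand side (hypothetical name): velocity depends on u AND on p
velocity(u, p, t) = re(p)(u)

uf = 1.0                                      # target position
tspan = (0.0, 2.0)                            # evolve until t = 2
prob = ODEProblem(velocity, [0.0], tspan, p0) # u(0) = 0

# Loss as an explicit function of p, so the adjoint can propagate gradients to it
function loss(p)
    sol = solve(prob, Tsit5(); p = p, saveat = [tspan[2]],
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return (Array(sol)[1, end] - uf)^2        # squared distance at t = 2
end

# Plain gradient descent on the flat weight vector (assumed step size η)
η = 0.1
p = copy(p0)
for epoch in 1:200
    g = Zygote.gradient(loss, p)[1]  # gradient w.r.t. all weights
    p .-= η .* g                     # in-place descent step
end
trained_model = re(p)                # network rebuilt with the trained weights
```

After the loop, `trained_model([u])` gives the learned velocity at position `u`. Using Flux's `Adam` through `Flux.setup` and `Flux.update!` instead of the hand-written step is another option.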

Any help here would be immensely appreciated!
Thanks in advance!

I also suspect that I may not be using the appropriate (or up-to-date) Julia framework for this type of problem.
How would you approach it?

Btw, the package versions are:

```
  [f6369f11] ForwardDiff v0.10.35
  [91a5bcdd] Plots v1.38.14
  [1ed8b502] SciMLSensitivity v7.32.0
  [90137ffa] StaticArrays v1.5.25
  [e88e6eb3] Zygote v0.6.61
```