Hello!

Say I have a differential equation of the form *du/dt = NN(u)*, where *NN(u)* is the output of a neural network that takes as input the current value of u.

The initial condition is *u(0) = 0*.

I want to train the neural network in such a way that, for a specified time *t=2* and position *uf=1*, the solution *u* to the differential equation obeys *u(t)=uf*.
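
Written out, with θ standing for the network weights (my own notation here, it does not appear in the code below), the objective I have in mind is roughly:

$$
\frac{du}{dt} = NN_\theta(u), \qquad u(0) = 0, \qquad \min_\theta \; L(\theta) = \bigl(u_\theta(2) - u_f\bigr)^2, \quad u_f = 1 .
$$

Here $u_\theta(2)$ is the solution at the final time *t=2* for a given set of weights.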

This is meant to be a toy version of the problem of teaching an agent to find paths…

After reading about Flux and DifferentialEquations.jl, I figured that I could approach the problem with the following code:

```
using Flux
using Flux: train!, params
using DifferentialEquations
using SciMLSensitivity

# define NN model
model = Chain(
    Dense(1, 4, relu),
    Dense(4, 1, x -> σ.(x))
)
ps = params(model)

# define differential equation
function f!(du, u, p, t)
    du[1] = model([u[1]])[1] # NN controls velocity
end

# calculate final position for our NN controller
function final_position()
    u0 = [0.] # starts at 0.
    tspan = (0.0, 2.0) # system evolves for time = 2.
    prob = ODEProblem(f!, u0, tspan) # set problem
    sol = solve(prob, Tsit5(), save_everystep = false, save_start = false) # solve diff equation
    Xf = sol.u[1][1] # final position (only the final state is saved)
    return Xf
end

# define loss function
uf = 1.
function loss()
    Xf = final_position() # final position with NN controller starting at u0 = 0
    return (Xf - uf)^2
end

# set optimizer
opt = ADAM(0.3)

# define (empty) data
x_train = Iterators.repeated((), 100)

# do one training round
train!(loss, ps, x_train, opt)
```

But I got a warning, and the parameters are not updated at all.

I guess that this is happening because it is having trouble auto-differentiating a function of the output of a differential equation…
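
To spell out what I think needs to be differentiated (a sketch only; θ for the network weights and $u_\theta$ for the corresponding solution are notation I am assuming here, not names from the code): for the loss $L(\theta) = (u_\theta(2) - u_f)^2$, the chain rule gives

$$
\frac{dL}{d\theta} = 2\,\bigl(u_\theta(2) - u_f\bigr)\,\frac{\partial u_\theta(2)}{\partial \theta},
$$

and, if I understand correctly, the sensitivity $s(t) = \partial u_\theta(t) / \partial \theta$ itself solves

$$
\frac{ds}{dt} = \frac{\partial NN_\theta}{\partial u}\, s + \frac{\partial NN_\theta}{\partial \theta}, \qquad s(0) = 0 .
$$

So the gradient has to be propagated through the ODE solve, and the solver needs some way to see how the right-hand side depends on the weights.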

But I could not find a way to work around it.

Any help here would be immensely appreciated!

Thanks in advance!

I also suspect that I may not be using the appropriate (or most up-to-date) Julia framework for this type of problem.

How would you approach the problem?

Btw, the package versions are:

```
[f6369f11] ForwardDiff v0.10.35
[91a5bcdd] Plots v1.38.14
[1ed8b502] SciMLSensitivity v7.32.0
[90137ffa] StaticArrays v1.5.25
[e88e6eb3] Zygote v0.6.61
```