Hello!
Say I have a differential equation of the form du/dt = NN(u), where NN(u) is the output of a neural network that takes as input the current value of u.
The initial condition is u(0) = 0.
I want to train the neural network so that, at a specified final time t = 2, the solution u of the differential equation reaches a target position uf = 1, i.e. u(2) = uf.
This is meant to be a toy version of the problem of teaching an agent to find paths…
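Written out, the training problem looks like this (θ stands for the network weights; the symbol is introduced here only for notation). The last two lines are the standard forward-sensitivity equations. They show that the gradient of the loss needs ∂u/∂θ, the derivative of the ODE solution with respect to the weights, and that derivative is what the solver's AD support has to supply.

$$
\begin{aligned}
\frac{du}{dt} &= \mathrm{NN}_\theta(u), \qquad u(0) = 0,\\
L(\theta) &= \big(u_\theta(2) - u_f\big)^2, \qquad u_f = 1,\\
\nabla_\theta L &= 2\big(u_\theta(2) - u_f\big)\, s(2), \qquad s(t) = \frac{\partial u_\theta(t)}{\partial \theta},\\
\frac{ds}{dt} &= \frac{\partial \mathrm{NN}_\theta}{\partial u}(u)\, s + \frac{\partial \mathrm{NN}_\theta}{\partial \theta}(u), \qquad s(0) = 0.
\end{aligned}
$$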
After reading about Flux.jl and DifferentialEquations.jl, I figured that I could approach the problem with the following code:
```julia
using Flux                      # brings Chain, Dense, relu, σ and ADAM into scope
using Flux: train!, params
using DifferentialEquations
using SciMLSensitivity

# define NN model
model = Chain(
    Dense(1, 4, relu),
    Dense(4, 1, σ)              # Dense already applies its activation element-wise
)
ps = params(model)

# define differential equation
function f!(du, u, p, t)
    du[1] = model([u[1]])[1]    # NN controls velocity
end

# calculate final position for our NN controller
function final_position()
    u0 = [0.0]                  # starts at 0
    tspan = (0.0, 2.0)          # system evolves for time = 2
    prob = ODEProblem(f!, u0, tspan) # set problem
    sol = solve(prob, Tsit5(), save_everystep = false, save_start = false) # solve diff equation
    Xf = sol.u[end][1]          # final position (only the end state is saved)
    return Xf
end

# define loss function
uf = 1.0
function loss()
    Xf = final_position()       # final position with NN controller starting at u0
    return (Xf - uf)^2
end

# set optimizer
opt = ADAM(0.3)
# define (empty) data: 100 empty tuples, so loss() is called with no arguments
x_train = Iterators.repeated((), 100)
# do one training round (100 optimizer steps)
train!(loss, ps, x_train, opt)
```
But I got the following warning, and then the parameters were not updated at all.
I guess this happens because automatic differentiation has trouble differentiating a function of the output of a differential equation…
But I could not find a way to work around it.
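For comparison, differentiating a function of an ODE solution does work once the weights are passed in as an explicit argument. The following is a minimal, dependency-light sketch under stated assumptions, not Flux or DifferentialEquations: a hypothetical hand-written 1→4→1 network whose 13 weights are packed into one vector θ, a fixed-step RK4 integrator standing in for Tsit5, and `ForwardDiff.gradient` (ForwardDiff appears in the package list below) for the gradient. The initial weights are arbitrary values chosen for illustration.

```julia
using ForwardDiff

# Hypothetical stand-in for Chain(Dense(1, 4, relu), Dense(4, 1, σ)):
# θ[1:4] = layer-1 weights, θ[5:8] = layer-1 biases,
# θ[9:12] = layer-2 weights, θ[13] = layer-2 bias.
sigmoid(x) = 1 / (1 + exp(-x))

function nn(u, θ)
    h = max.(zero(u), θ[1:4] .* u .+ θ[5:8])   # hidden layer with relu
    return sigmoid(sum(θ[9:12] .* h) + θ[13])  # velocity in (0, 1)
end

# Fixed-step RK4 for du/dt = nn(u, θ) on [0, T] with u(0) = 0.
# It is written generically so ForwardDiff's dual numbers flow through it.
function final_position(θ; T = 2.0, nsteps = 100)
    dt = T / nsteps
    u = zero(eltype(θ))
    for _ in 1:nsteps
        k1 = nn(u, θ)
        k2 = nn(u + dt / 2 * k1, θ)
        k3 = nn(u + dt / 2 * k2, θ)
        k4 = nn(u + dt * k3, θ)
        u += dt / 6 * (k1 + 2k2 + 2k3 + k4)
    end
    return u
end

uf = 1.0
loss(θ) = (final_position(θ) - uf)^2

# Hand-picked initial weights; the negative output bias makes u(2) start well short of uf.
θ = [0.5, -0.3, 0.8, 0.2,  0.1, 0.2, 0.1, 0.3,  0.3, -0.2, 0.4, 0.1,  -2.0]
initial_loss = loss(θ)

# Plain gradient descent on the explicit weight vector.
η = 0.5
for _ in 1:200
    θ .-= η .* ForwardDiff.gradient(loss, θ)   # in-place update of the global vector
end

println("loss: ", initial_loss, " -> ", loss(θ))
```

Forward mode is a reasonable fit here only because there are 13 weights. For larger networks, reverse-mode adjoints (what SciMLSensitivity provides) scale better.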
Any help here would be immensely appreciated!
Thanks in advance!
I also suspect that I may not be using the appropriate (or up-to-date) Julia framework for this type of problem.
How would you approach the problem?
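One pattern that may address this, sketched under the assumption that `Flux.destructure` and the `p` keyword of `solve` behave as in the SciMLSensitivity neural-ODE examples, is to flatten the network into a parameter vector, pass that vector as the ODE's `p`, and take explicit Zygote gradients instead of using implicit `params`. The learning rate, step count, seed, and the `InterpolatingAdjoint(autojacvec = ZygoteVJP())` choice are illustrative assumptions, not tuned values.

```julia
using Flux, DifferentialEquations, SciMLSensitivity, Zygote, Random

Random.seed!(1234)

# Same architecture as in the question.
model = Chain(Dense(1 => 4, relu), Dense(4 => 1, σ))

# Flatten the weights into a vector `p`; `re(p)` rebuilds the network from such a vector.
p, re = Flux.destructure(model)

# Out-of-place right-hand side: the weights arrive through the ODE parameter slot,
# so the adjoint of `solve` can propagate gradients back to them.
dudt(u, p, t) = re(p)(u)

u0 = Float32[0.0]
tspan = (0.0f0, 2.0f0)
prob = ODEProblem(dudt, u0, tspan, p)
uf = 1.0f0

# Final position u(2) for a given parameter vector θ (only t = 2 is saved).
function final_position(θ)
    sol = solve(prob, Tsit5(); p = θ, saveat = [tspan[2]],
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return Array(sol)[1, end]
end

loss(θ) = (final_position(θ) - uf)^2

loss_before = loss(p)
η = 0.1f0
for _ in 1:100
    g = Zygote.gradient(loss, p)[1]   # explicit gradient w.r.t. the flat weight vector
    p .-= η .* g                      # plain gradient-descent step, in place
end
loss_after = loss(p)
println("loss: ", loss_before, " -> ", loss_after)
```

The main difference from the code in the question is that `dudt` reads the weights from its `p` argument instead of from the global `model`. That is probably why the original version produced no parameter updates: the solver was never told that the solution depends on those weights.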
Btw, the package versions are:
```
[f6369f11] ForwardDiff v0.10.35
[91a5bcdd] Plots v1.38.14
[1ed8b502] SciMLSensitivity v7.32.0
[90137ffa] StaticArrays v1.5.25
[e88e6eb3] Zygote v0.6.61
```