Hello everyone,
following the introduction to DiffEqFlux.jl (DiffEqFlux.jl – A Julia Library for Neural Differential Equations), I've set up a slightly modified example using a different differential equation. My code works fine and I'm quite happy with the results.
However, for later usage I need to train the neural network with an excitation signal and not just from initial conditions. I figured out how to define an ODEProblem for a given excitation but I’m failing to do the same for my NeuralODE.
Hence my question: how can I modify my working code so that the NN is excited by an external signal? I.e., using the function linear_pendulum_excited and passing the excitation to the NN as well.
I am aware of a rather similar question by Volker (Fitting a dynamic system with an exogenous input (nonhomogenous neural ode) via DiffEqFlux - #25 by Volker). There the problem is defined using the conventional ODEProblem instead of NeuralODE, which confused me quite a bit. Can I keep using NeuralODE, or do I have to restructure my code?
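For context, this is the direction I would try based on that thread, though I'm not sure it's the intended way: destructure the network and add the excitation inside a plain ODEProblem. This is only a sketch, not verified code; `dudt`, `ex`, `u0`, and `tspan` are as in my code below, and the `Flux.destructure` usage is my assumption from the linked thread:

```julia
# Sketch only -- my reading of Volker's thread, not verified:
# flatten the network parameters and inject the excitation manually.
p_nn, re = Flux.destructure(dudt)   # re(p) rebuilds the Chain from flat params

function nn_excited!(du, u, p, t)
    du .= re(p)(u)    # learned dynamics from the network
    du[2] += ex(t)    # add the external excitation to the second state
end

prob_nn = ODEProblem(nn_excited!, u0, tspan, p_nn)
# sol = solve(prob_nn, Tsit5(), saveat=time)
```

But then I lose the NeuralODE layer entirely, which is exactly what confuses me.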
Thank you very much in advance!
Please find my code below:
using DifferentialEquations
using Plots
using Flux, DiffEqFlux
function linear_pendulum(du, u, p, t)
    x, y = u
    α = p[1]
    du[1] = dx = y
    du[2] = dy = -x - α*y
end
function linear_pendulum_excited(du, u, p, t)
    x, y = u
    α = p[1]
    du[1] = dx = y
    du[2] = dy = -x - α*y + ex(t)
end
neuralODE = true  # Julia's boolean literal is lowercase
if neuralODE
    # Create the excitation signal
    amp = 2
    freq = 1
    ex(t) = amp*sin(freq*t)

    # Create a reference solution (use Float64 state and tspan so the
    # in-place RHS can write non-integer derivatives)
    u0 = [-1.0, 2.0]
    datasize = 500
    tspan = (0.0, 30.0)
    time = range(tspan[1], tspan[2], length=datasize)
    p = [0.5]
    ode = ODEProblem(linear_pendulum, u0, tspan, p)
    ref = Array(solve(ode, Tsit5(), saveat=time))

    # Build a simple NN to approximate the ODE
    dudt = Chain(Dense(2, 10, tanh), Dense(10, 2))

    # Define the neural ODE problem
    n_ode = NeuralODE(dudt, tspan, Tsit5(), saveat=time, reltol=1e-7, abstol=1e-9)
    ps = Flux.params(n_ode)

    # Define the prediction step as one whole run of the neural ODE
    function predict_n_ode()
        n_ode(u0)
    end

    # Loss function: elementwise distance between prediction and reference
    loss_n_ode() = sum(abs2, ref .- predict_n_ode())

    # Optimization part of the code
    data = Iterators.repeated((), 100)
    opt = ADAM(0.05)

    # Callback function to observe training --> plotting
    cb = function ()
        display(loss_n_ode())
        # Plot the current prediction against the data
        cur_pred = predict_n_ode()
        pl = plot(time, ref[2,:], label="data")
        plot!(pl, time, cur_pred[2,:], label="prediction")
        display(plot(pl))
    end

    # Display the ODE with the initial parameter values
    cb()

    # Start the training
    Flux.train!(loss_n_ode, ps, data, opt, cb=cb)
end