Following this introduction to DiffEqFlux.jl (https://julialang.org/blog/2019/01/fluxdiffeq/), I’ve set up a slightly modified example (I’m using a different differential equation). My code works fine and I’m quite happy with the results.
However, for later use I need to train the neural network with an excitation signal, not just from initial conditions. I figured out how to define an ODEProblem for a given excitation, but I’m failing to do the same for my NeuralODE.
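For context, a minimal sketch of such an excited ODEProblem, assuming the sinusoidal excitation and the parameter values (α = 0.5, u0 = [-1, 2], 500 save points on [0, 30]) from the code at the end of this post. The `!` suffix and the names `tsteps`, `prob_ex` and `ref_ex` are hypothetical naming choices, not taken from my code. The excitation enters only through the explicit time argument `t` of the right-hand side:

```julia
using DifferentialEquations

# Sinusoidal excitation (same amplitude/frequency as in the code below)
amp, freq = 2.0, 1.0
ex(t) = amp * sin(freq * t)

# Damped linear pendulum driven by ex(t): x'' + α x' + x = ex(t)
function linear_pendulum_excited!(du, u, p, t)
    x, y = u
    α = p[1]
    du[1] = y                   # dx/dt
    du[2] = -x - α * y + ex(t)  # dy/dt, forced by the excitation
    return nothing
end

u0 = [-1.0, 2.0]            # Float64 so the in-place du can hold non-integers
tspan = (0.0, 30.0)
tsteps = range(tspan[1], tspan[2], length = 500)
prob_ex = ODEProblem(linear_pendulum_excited!, u0, tspan, [0.5])

# 2×500 matrix: row 1 = x(t), row 2 = y(t) at the save points
ref_ex = Array(solve(prob_ex, Tsit5(), saveat = tsteps))
```

Since `ex` is an ordinary function of `t`, the solver can evaluate it at whatever intermediate times it chooses.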
Hence my question: how can I modify my working code so that the NN is driven by an external signal? I.e., use the function linear_pendulum_excited and also pass the excitation to the NN.
I am aware of a rather similar question by Volker (Fitting a dynamic system with an exogenous input (nonhomogenous neural ode) via DiffEqFlux). There, the problem is defined using the conventional ODEProblem instead of NeuralODE, which confused me quite a bit. Can I use NeuralODE, or do I have to change my code?
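One possible route, sketched under assumptions: NeuralODE appears to pass only the state `u` to the network, so the excitation could be supplied by building the ODEProblem by hand, as in Volker's question. The Flux model is flattened with `Flux.destructure`, and the right-hand side appends `ex(t)` to the network input, so the first layer takes 3 inputs instead of 2. This only covers the forward pass. The names `nn`, `nn_rhs`, `prob_nn` and `predict` are hypothetical, and differentiating through `solve` for training would additionally rely on the sensitivity machinery that DiffEqFlux loads.

```julia
using DifferentialEquations, Flux

# Excitation signal (Float32 to match Flux's default parameter type)
amp, freq = 2.0f0, 1.0f0
ex(t) = amp * sin(freq * t)

# Network input: 2 states + 1 excitation value -> 2 state derivatives
nn = Chain(Dense(3, 10, tanh), Dense(10, 2))

# Flatten all weights into one vector θ; `re` rebuilds the network from θ
θ, re = Flux.destructure(nn)

# Out-of-place RHS: the network sees [x, y, ex(t)] instead of just [x, y]
nn_rhs(u, p, t) = re(p)(vcat(u, ex(t)))

u0 = Float32[-1, 2]
tspan = (0.0f0, 30.0f0)
tsteps = range(tspan[1], tspan[2], length = 500)
prob_nn = ODEProblem(nn_rhs, u0, tspan, θ)

# Prediction for a given parameter vector (remake swaps in new parameters)
predict(p) = Array(solve(remake(prob_nn; p = p), Tsit5(), saveat = tsteps))

pred = predict(θ)   # 2×500 matrix, same layout as the reference data
```

Training would then presumably follow the same pattern as the NeuralODE version, e.g. with a loss like `sum(abs2, ref .- predict(θ))` and `θ` as the trainable parameters. If the excitation is only available as samples, it would have to be interpolated so that `ex(t)` can be evaluated at arbitrary solver times.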
Thank you very much in advance!
Please find my code below:
```julia
using DifferentialEquations
using Plots
using Flux, DiffEqFlux

function linear_pendulum(du, u, p, t)
    x, y = u
    α = p[1]
    du[1] = dx = y
    du[2] = dy = -x - α*y
end

function linear_pendulum_excited(du, u, p, t)
    x, y = u
    α = p[1]
    du[1] = dx = y
    du[2] = dy = -x - α*y + ex(t)
end

neuralODE = true

if neuralODE
    # Create excitation signal
    amp = 2
    freq = 1
    ex(t) = amp*sin(freq*t)

    # Create a reference solution (Float64 initial state for the in-place solver)
    u0 = [-1.0, 2.0]
    datasize = 500
    tspan = (0.0, 30.0)
    time = range(tspan[1], tspan[2], length=datasize)
    p = [0.5]
    ode = ODEProblem(linear_pendulum, u0, tspan, p)
    ref = Array(solve(ode, Tsit5(), saveat=time))

    # Build simple NN to approximate the ODE
    dudt = Chain(Dense(2, 10, tanh), Dense(10, 2))

    # Define the neural ODE problem
    n_ode = NeuralODE(dudt, tspan, Tsit5(), saveat=time, reltol=1e-7, abstol=1e-9)
    ps = Flux.params(n_ode)

    # Define prediction step as one whole run of the neural ODE problem
    function predict_n_ode()
        n_ode(u0)
    end

    # Loss function defined as elementwise distance between prediction and reference
    loss_n_ode() = sum(abs2, ref .- predict_n_ode())

    # Optimization part of the code
    data = Iterators.repeated((), 100)
    opt = ADAM(0.05)

    # Callback function to observe training --> plotting
    cb = function ()
        display(loss_n_ode())
        # Plot current prediction against data
        cur_pred = predict_n_ode()
        pl = plot(time, ref[2, :], label="data")
        plot!(pl, time, cur_pred[2, :], label="prediction")
        display(plot(pl))
    end

    # Display the ODE with the initial parameter values
    cb()

    # Start the training
    Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
end
```