Hi!
My academic or theoretical vocabulary isn’t extensive, so please excuse any lack of terminology on my part. Please point out the correct terminology; I’m more than happy to learn.
I’ve been following along with the DiffEqFlux announcement blog post and have found it absolutely great so far.
Here’s a working example of a simple spring-mass-damper system that I first solve using the regular DifferentialEquations.jl library (to get my target values) and then fit with a neural ODE (as per the blog post):
```julia
using DifferentialEquations
using Flux, DiffEqFlux

# The system I'm trying to solve
function msd_system(du, u, p, t)
    m, k, c = p  # Mass, spring, damper
    g = 9.81
    du[1] = u[2]  # ẋ = v (u[1] is position, u[2] is velocity)
    du[2] = (-g*m - k*u[1] - c*u[2])/m
end

# Parameters
m = 1
k = 5
c = 1
p = [m, k, c]
u0 = Float32[1., 0.]
tspan = (0., 4.)
ts = range(0., 4., length=300)

prob = ODEProblem(msd_system, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat=ts);

# ODE Neural Network
dudt = Chain(
    Dense(2, 50, swish),
    Dense(50, 2)
)
ps = Flux.params(dudt);
n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=ts)

function predict_n_ode()
    n_ode(u0)
end
loss_n_ode() = sum(abs, sol .- predict_n_ode())

# Callback
cb = function ()
    display(loss_n_ode())
end

# Training
data = Iterators.repeated((), 1000)
opt = ADAM()
cb()  # Test call
Flux.train!(loss_n_ode, ps, data, opt, cb=Flux.throttle(cb, 1))
```
This works perfectly, as expected.
The problem: let’s modify the original spring-mass-damper system to include a time-dependent external force, `F_acc(t)` (e.g. someone has decided to poke the mass). Our system then becomes:
```julia
# Modified system
function msd_system_modified(du, u, p, t)
    m, k, c, F_acc = p  # Mass, spring, damper, force function F_acc(t)
    g = 9.81
    du[1] = u[2]  # ẋ = v
    du[2] = (F_acc(t) - g*m - k*u[1] - c*u[2])/m  # Modified: external force term added
end
```
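Written out with u₁ = x and u₂ = ẋ, the modified system is as follows. The only change from `msd_system` is the F_acc(t) term, which makes the right-hand side depend explicitly on t; in ODE terminology this is a non-autonomous system, whereas the original one is autonomous.

$$
\begin{aligned}
m\,\ddot{x} &= F_{\mathrm{acc}}(t) - m g - k x - c\,\dot{x}, \\
\dot{u}_1 &= u_2, \\
\dot{u}_2 &= \frac{F_{\mathrm{acc}}(t) - m g - k u_1 - c\,u_2}{m}.
\end{aligned}
$$

Setting F_acc(t) = 0 recovers exactly the equations in `msd_system`.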
Here, our regular ODE solver will still succeed without complaint. However, our neural ODE, which only receives the current state `u` and returns the derivative `du`, will not be able to learn to model this time-dependent force, since it is blissfully unaware of `t`.
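A small plain-Julia check makes that concrete: evaluate the right-hand side of the modified system at the same state `u` but at two different times. `msd_rhs` is a hypothetical out-of-place copy of `msd_system_modified`, and `poke` is a hypothetical force (10 N between t = 0.5 and t = 0.7). Because the two derivatives differ for an identical `u`, no function of `u` alone can reproduce both.

```julia
# Hypothetical out-of-place version of msd_system_modified (returns du instead of mutating it)
function msd_rhs(u, p, t)
    m, k, c, F_acc = p
    g = 9.81
    return [u[2], (F_acc(t) - g*m - k*u[1] - c*u[2]) / m]
end

# Hypothetical poke: a constant 10 N push between t = 0.5 and t = 0.7
poke(t) = 0.5 <= t < 0.7 ? 10.0 : 0.0

p_mod = (1.0, 5.0, 1.0, poke)  # (m, k, c, F_acc)
u = [1.0, 0.0]                 # the same state at both times

du_quiet = msd_rhs(u, p_mod, 0.2)  # no force acting
du_poked = msd_rhs(u, p_mod, 0.6)  # during the poke

# Same state, different derivatives: the accelerations differ by F_acc(0.6)/m = 10
println(du_quiet)
println(du_poked)
```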
I’d like to be able to do something along the lines of the following:
```julia
model = Chain(
    Dense(3, 50, swish),  # <-- (u[1], u[2], t), for example
    Dense(50, 2)
)
```
where the additional input lets me pass in the instantaneous value of `t`.
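A minimal sketch of that wiring, in plain Julia with no packages so every step is visible: a 3 → 50 → 2 network with a hand-written `swish` receives `vcat(u, t)`, and a fixed-step RK4 integrator passes the current stage time to it. The helpers `make_mlp` and `rk4_solve` and the random initialization are hypothetical, and there is no training loop; this only shows the forward pass. As far as I understand the blog-post version of `neural_ode`, it builds the right-hand side from `u` alone, so in DiffEqFlux this would presumably mean writing your own right-hand side of the form `(u, p, t) -> model(vcat(u, t))` instead; check the docs for your DiffEqFlux version.

```julia
using Random

# Swish activation, written out so the sketch needs no packages
swish(x) = x / (1 + exp(-x))

# Hypothetical helper: a 3 -> hidden -> 2 dense network as a closure.
# Input is [u1, u2, t]; output is the predicted [du1, du2].
function make_mlp(rng; n_in = 3, n_hidden = 50, n_out = 2)
    W1 = 0.1 .* randn(rng, n_hidden, n_in)
    b1 = zeros(n_hidden)
    W2 = 0.1 .* randn(rng, n_out, n_hidden)
    b2 = zeros(n_out)
    return x -> W2 * swish.(W1 * x .+ b1) .+ b2
end

# Vector field that feeds the current time to the network alongside the state
rhs(model, u, t) = model(vcat(u, t))

# Hypothetical helper: classic fixed-step RK4, saving the state at every step
function rk4_solve(f, u0, tspan, n_steps)
    t0, t1 = tspan
    h = (t1 - t0) / n_steps
    ts = range(t0, t1; length = n_steps + 1)
    us = Vector{Vector{Float64}}(undef, n_steps + 1)
    us[1] = copy(u0)
    u = copy(u0)
    for i in 1:n_steps
        t = ts[i]
        k1 = f(u, t)
        k2 = f(u .+ (h / 2) .* k1, t + h / 2)
        k3 = f(u .+ (h / 2) .* k2, t + h / 2)
        k4 = f(u .+ h .* k3, t + h)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)  # new array, so us entries don't alias
        us[i + 1] = u
    end
    return ts, us
end

rng = MersenneTwister(0)
model = make_mlp(rng)
# 299 steps gives 300 saved points on (0, 4), matching saveat=ts in the original post
ts, us = rk4_solve((u, t) -> rhs(model, u, t), [1.0, 0.0], (0.0, 4.0), 299)
```

Since `t` enters the network on every call, the learned vector field is free to depend on time, which is what a force like `F_acc(t)` requires.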
I’ve had success implementing something like this with the Python HIPS/Autograd project, where you have a little more fine-grained control over where parameters go, but as a very recent newcomer to Julia (and no expert in automatic differentiation), I’m a bit lost as to how I could do something similar using DiffEqFlux.jl.
Any help or pointers in the right direction? Thanks!