DiffEqFlux with time as additional input to Neural ODE

Hi!

My academic or theoretical vocabulary isn’t extensive, so please excuse any lack of terminology on my part. Please point the correct terminology out to me; I’m more than happy to learn.

I’ve been enjoying and following along with the DiffEqFlux announcement blog post, and have found things absolutely great so far.

Here’s a working example of a simple spring-mass-damper system that I first solve using the regular DifferentialEquations.jl library (in order to find my target values), and then proceed to solve using a Neural ODE (as per the blog post):

using DifferentialEquations
using Flux, DiffEqFlux

# The system I'm trying to solve
function msd_system(du, u, p, t)
    m, k, c = p  # Mass, spring, damper
    g = 9.81
    du[1] = u[2]  # ẋ = v
    du[2] = (-g*m - k*u[1] - c*u[2])/m
end

# Parameters
m = 1
k = 5
c = 1
p = [m, k, c]
u0 = Float32[1., 0.]
tspan = (0., 4.)
ts = range(0., 4., length=300)

prob = ODEProblem(msd_system, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat=ts);

# ODE Neural Network
dudt = Chain(
    Dense(2, 50, swish),
    Dense(50, 2)
)
ps = Flux.params(dudt);

n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=ts)

function predict_n_ode()
    n_ode(u0)
end

loss_n_ode() = sum(abs, sol .- predict_n_ode())

# Callback
cb = function()
    display(loss_n_ode())
end

# Training
data = Iterators.repeated((), 1000)
opt = ADAM()
cb()  # Test call
Flux.train!(loss_n_ode, ps, data, opt, cb=Flux.throttle(cb, 1))

This works perfectly, as expected.

The problem: Let’s modify the original spring-mass-damper system to include a time-dependent external force (e.g. someone has decided to poke the mass with a force F_acc(t)). Our system then becomes:

# Modified system
function msd_system_modified(du, u, p, t)
    m, k, c, F_acc = p  # Mass, spring, damper, Force function
    g = 9.81
    du[1] = u[2]  # ẋ = v
    du[2] = (F_acc(t) - g*m - k*u[1] - c*u[2])/m  # Modified
end
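
For reference, the force function and parameter vector for this modified system could be defined along these lines (the step force below is just one illustrative choice; the rest of the solve is unchanged):

# A hypothetical time-dependent force: a constant poke between t = 1 and t = 2
F_acc(t) = (1 < t < 2) ? 3*9.81 : 0.0

p_mod = [m, k, c, F_acc]  # p can mix numbers and a function
prob_mod = ODEProblem(msd_system_modified, u0, tspan, p_mod)
sol_mod = solve(prob_mod, Tsit5(), saveat=ts)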

Here, our regular ODE solver will still succeed without complaint. However, our Neural ODE (which only receives the current state u and returns the derivative du) will not be able to learn how to model this time-dependent force, since it is blissfully unaware of t.

I’d like to be able to do something along the following lines:

model = Chain(
        Dense(3, 50, swish),  # <-- (u[1], u[2], t), for example
        Dense(50, 2)
    )

where I have an additional input parameter that I can use to pass in instantaneous values of t.
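
In other words, the forward pass I have in mind would look roughly like this (just a sketch of the call, not working training code):

model([u; t])  # current state plus the instantaneous time, as a 3-element input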

I’ve had success implementing something like this with the Python HIPS/Autograd project, where you have a little more fine-grained control over where parameters go, but as a very recent newcomer to Julia (and not an expert in automatic differentiation) I’m a bit lost as to how I could do something similar using DiffEqFlux.jl.

Any help or pointers in the right direction? Thanks!


The Neural ODE layer is a helper for a very specific case. Directly define the neural ODE if you want to do more fancy things, like having part be a known ODE and part be neural, or adding some time dependencies:
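
(Sketch only; model, u0, tspan, and ts here are whatever you’re already using.)

function dudt_(u, p, t)
    # Append the current time to the state so the network can see it;
    # known physics terms could be added alongside the network output here too.
    Flux.Tracker.collect(model([u; t]))
end

prob = ODEProblem(dudt_, u0, tspan)
# ...and then differentiate through it with diffeq_rd(p, prob, Tsit5(), saveat=ts) as usual.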

I am not sure we can easily add it to the layer function without cluttering the interface a bit, but open an issue and we can look into it.


Ah, ok! Thanks for clarifying; it makes a lot more sense now that I know neural_ode is a helper. I’m perfectly happy with defining things directly when I need more flexibility; I had believed the documentation was implying that neural_ode was the de facto way to define the Neural ODE.

It might be worth including a simpler example of how to do something a bit more custom, since while the Mixed Neural DEs example shows how to define a Neural ODE directly, it covers a very particular use case and is mixed in with a lot of other complexity.

Using your advice (thanks!), this is what I currently have and seems to be working a treat:

using DifferentialEquations
using Flux, DiffEqFlux
using Plots

# The system I'm trying to solve
function msd_system(du, u, p, t)
    m, k, c = p  # Mass, spring, damper

    # Hacky time-dependent force
    if t > 1 && t < 2
        F = 3*9.81
    else
        F = 0.
    end

    g = 9.81
    du[1] = u[2]  # ẋ = v
    du[2] = (F - g*m - k*u[1] - c*u[2])/m
end


# Parameters --- normal ODE
m = 1.
k = 5.
c = 1.
p = [m, k, c]

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 4.0f0)
ts = range(0.0f0, 4.0f0, length=300)

prob = ODEProblem(msd_system, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat=ts);  # This will act as our target

# -- ODE Neural Network --

# Model
model = Chain(
    Dense(3, 50, swish),
    Dense(50, 2)
)
ps_m = Flux.params(model)

# Custom Neural ODE
function dudt_(u, p, t)
    input = [u; t]
    Flux.Tracker.collect(model(input))
end

p = Float32[0.0]
p = param(p)  # Seems like `p` must be included in `diffeq_rd`, even if unused by Neural ODE?
_u0 = param(u0)
prob_n_ode = ODEProblem(dudt_, u0, tspan)
diffeq_rd(p, prob_n_ode, Tsit5())  # Test run

function predict_rd()
    Flux.Tracker.collect(diffeq_rd(p, prob_n_ode, Tsit5(), saveat=ts, u0=_u0))
end

loss_rd() = sum(abs2, sol .- predict_rd())
loss_rd()  # Test run

# Callback
cb = function()
    display(loss_rd())
end

data = Iterators.repeated((), 1000)
opt = ADAM()
cb()  # Test call
Flux.train!(loss_rd, ps_m, data, opt, cb=Flux.throttle(cb, 1))

As a final question, could you perhaps clarify what seems to be the required inclusion of tracked parameters p in diffeq_rd and in the direct Neural ODE definition, even when they aren’t a dependency of the Neural ODE? Is there no way to exclude them? It seems messy otherwise.
I’ve commented the appropriate line.

Thanks again — and apologies for the misunderstanding on my part.
Michael.

Thanks. I’ll add an example to the README for directly declaring the simple Neural ODE. I agree that’s probably a good place for people to start tweaking from.

I would use the out-of-place form here instead:

function msd_system(u, p, t)
    m, k, c = p  # Mass, spring, damper

    # Hacky time-dependent force
    if t > 1 && t < 2
        F = 3*9.81
    else
        F = 0.
    end

    g = 9.81
    Flux.Tracker.collect([u[2], (F - g*m - k*u[1] - c*u[2])/m])
end

Flux’s AD likes this form better, so it’ll be faster for reverse-mode autodiff (but slower for adjoints).
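
For reference, the adjoint route is the same kind of call with diffeq_adjoint swapped in for diffeq_rd; roughly like the sketch below (check the README for the exact keyword arguments and problem setup the adjoint expects):

# Adjoint sensitivity analysis instead of reverse-mode AD through the solver
predict_adjoint() = diffeq_adjoint(p, prob_n_ode, Tsit5(), saveat=ts)
loss_adjoint() = sum(abs2, sol .- predict_adjoint())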


Ah, fantastic! Thank you for the suggestion.

Is there a resource for these kinds of performance trade-offs? I’ve been finding it a bit difficult to develop the right “nose” for what will work better in certain scenarios.

I went back through the documentation and realized that, yes, if one doesn’t know the implementation details this is hard to suss out. Thus there is now a new section in the README which discusses performance.

I also realized that many users didn’t realize I was demoing all combinations of forward/reverse/adjoint + in-place/out-of-place + CPU/GPU, and that some of the examples were purposely set up for their performance regimes. I made the examples clearer and noted where the performance regimes occur. Let me know if you’d like to see any more changes.


Hi Chris,

I’ve just had a quick read through the new README. This is fantastic. I greatly appreciate the update; it makes things a lot clearer to follow. Can’t thank you enough.