Creating a Neural ODE in Lux that explicitly depends on time


I am very excited to use Lux for neural ODE training, but I am having trouble understanding how to extend a neural ODE so that it accepts the time t as an argument.

I noticed this thread: DiffEqFlux with time as additional input to Neural ODE - #3 by Michael_Struwig
which discusses how to do this for neural ODEs in Flux, but the approach doesn't seem to translate directly to the Lux interface. Any help would be much appreciated!

Just put t into the neural network.

That thread's main points work exactly the same in the Lux interface. What did you try?
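Here is a minimal sketch of "putting t into the network", assuming recent versions of Lux.jl and OrdinaryDiffEq.jl. The right-hand side appends the scalar t to the state vector before calling the Lux model, so the first layer takes one more input than the state has. The sizes (2 state variables, 16 hidden units) are arbitrary choices for illustration. The sketch only solves the ODE forward and is not a training loop.

```julia
using Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Illustrative sizes: 2 state variables plus 1 extra input for time t
model = Chain(Dense(3 => 16, tanh), Dense(16 => 2))
ps, st = Lux.setup(rng, model)

# Out-of-place right-hand side: the network sees the state and t stacked together
function dudt(u, p, t)
    du, _ = model([u; t], p, st)   # Lux models return (output, state)
    return du
end

u0 = Float32[1.0, 0.0]
prob = ODEProblem(dudt, u0, (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5())
```

Because t is just another input feature, the Lux interface itself does not change. Only the first layer's input dimension grows by one.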


I tried just putting t into the network, but I needed to redefine the neural ODE layer as follows. Here I generalize it so the network can accept any function of time, control(t):

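using Lux, OrdinaryDiffEq, SciMLSensitivity

# Assumes a struct ControlledNeuralODE <: Lux.AbstractExplicitContainerLayer{(:model,)} with
# fields model, control, solver, sensealg, tspan, kwargs (mirroring the tutorial's NeuralODE).
function ControlledNeuralODE(model::Lux.AbstractExplicitLayer, control;
  solver=Tsit5(),
  sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
  tspan=(0.0f0, 1.0f0),
  kwargs...)
  return ControlledNeuralODE(model, control, solver, sensealg, tspan, kwargs)
end

function (n::ControlledNeuralODE)(x, ps, st)
  function dudt(u, p, t)
    c = n.control(t)                 # evaluate the time-dependent input at time t
    u_, st = n.model([u; c], p, st)  # the network sees the state and the control stacked
    return u_
  end
  prob = ODEProblem(ODEFunction(dudt), x, n.tspan, ps)
  return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end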

I just wanted to verify that this is a reasonable way to redefine the neural ODE. I followed the same outline as in: MNIST Classification using NeuralODE - Lux.jl

That looks fine as long as the input size for the model is correct, i.e. the first layer accepts the state dimension plus the control dimension. Did it not work?
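To make the input-size requirement concrete: with [u; c] as the network input, the first layer must accept length(u) + length(control(t)) features, and the last layer must return length(u) outputs. The sketch below assumes a hypothetical two-dimensional control, control(t) = [sin(t), cos(t)], and uses a plain closure rather than the ControlledNeuralODE layer, so it runs without the struct definition. The sizes are illustrative only.

```julia
using Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()
Random.seed!(rng, 42)

n_state = 2                          # length of the ODE state u
control(t) = [sin(t), cos(t)]        # hypothetical control signal of length 2
n_control = length(control(0.0f0))

# First layer takes n_state + n_control inputs because the network sees [u; control(t)];
# the last layer returns n_state outputs because it represents du/dt.
model = Chain(Dense(n_state + n_control => 16, tanh), Dense(16 => n_state))
ps, st = Lux.setup(rng, model)

function dudt(u, p, t)
    du, _ = model([u; control(t)], p, st)
    return du
end

u0 = Float32[1.0, 0.0]
prob = ODEProblem(dudt, u0, (0.0f0, 2.0f0), ps)
sol = solve(prob, Tsit5())
```

If the first layer were declared as Dense(n_state => 16, tanh) instead, the call inside dudt would fail with a dimension mismatch, which is the failure mode the reply above is warning about.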

Okay, perfect. Yes, the way I wrote it seems to do the trick!