Lux.jl Neural ODE Initial Condition Optimization

How do we also optimize the initial condition of a neural ODE using Lux.jl? I have not been able to find any examples of this. Is there a way to just add the initial conditions as parameters to a neural ODE defined by Lux.Chain, or do we have to create our own custom layer if we want to optimize over the initial conditions? Thank you!

You don’t have to do anything special. The solve already differentiates with respect to the initial conditions of the ODE, so if you make the initial conditions part of the state of the optimization problem, it will optimize them just like the parameters.
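The idea is roughly this (a minimal sketch; node, ps, and st stand in for your NeuralODE, the Lux parameters, and the Lux state):

θ = ComponentArray(p = ps, u0 = Float32[2.0, 0.0])  # the initial condition is now part of θ
predict(θ) = Array(node(θ.u0, θ.p, st)[1])          # so the optimizer updates it along with the weights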

Can you share what code you tried?

I see. I’ve just been following the main Neural ODE example from the DiffEqFlux website:

using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Lux.Chain(x -> x.^3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function predict_neuralode(p)
  Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

# Do not plot by default for the documentation
# Users should change doplot = true to see the plots from the callback
callback = function (p, l, pred; doplot = false)
  println(l)
  # plot current prediction against data
  if doplot
    plt = scatter(tsteps, ode_data[1,:], label = "data")
    scatter!(plt, tsteps, pred[1,:], label = "prediction")
    display(plot(plt))
  end
  return false
end

pinit = ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(optprob,
                                       ADAM(0.05),
                                       callback = callback,
                                       maxiters = 300)

Here, the parameter vector p won’t naturally contain the initial condition of the ODE unless I change it somehow.

So in the predict function, just make u0 a function of p by subsetting it, or use a ComponentArray and make it a separate component of what’s being optimized?

I just don’t understand how to create an entirely new set of parameters to optimize in Lux. In PyTorch this can be easily done by defining:

u0 = torch.nn.Parameter(torch.zeros(n, d))

Then I could optimize these parameters along with the weights and biases of my neural network.

Also, as a separate question: do any of the optimization methods use continuous sensitivity analysis, or do they discretize first and then differentiate?

We have some of all of it. If you want the details, read:

https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/
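If you want to choose explicitly, the sensitivity method can be passed through to the solve via the sensealg keyword. A rough sketch on the code above (the particular choices are only examples, not a recommendation):

using SciMLSensitivity

# continuous adjoint sensitivity analysis
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
                           sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))

# or a discretize-then-differentiate approach, e.g.
# sensealg = ReverseDiffAdjoint()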

Just extend the ComponentArray. If no one else gets around to it, I’ll post it when I’m at a computer.

Thank you, yes this would be super useful to see an example!

Hi @ChrisRackauckas, I’m still having trouble understanding how to add the initial condition into the optimization. Any advice or a short example would be much appreciated!

I don’t have time to run it locally, but this is what it should look like with your code:

using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Lux.Chain(x -> x.^3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

# pack the network parameters and the initial condition into one ComponentArray so both get optimized
theta_init = ComponentArray(p = p, u0 = u0)

function predict_neuralode(theta)
  Array(prob_neuralode(theta.u0, theta.p, st)[1])
end

function loss_neuralode(theta)
    pred = predict_neuralode(theta)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

# Do not plot by default for the documentation
# Users should change doplot = true to see the plots from the callback
callback = function (theta, l, pred; doplot = false)
  println(l)
  # plot current prediction against data
  if doplot
    plt = scatter(tsteps, ode_data[1,:], label = "data")
    scatter!(plt, tsteps, pred[1,:], label = "prediction")
    display(plot(plt))
  end
  return false
end

callback(theta_init, loss_neuralode(theta_init)...; doplot=true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((theta, x) -> loss_neuralode(theta), adtype)
optprob = Optimization.OptimizationProblem(optf, theta_init)

result_neuralode = Optimization.solve(optprob,
                                       ADAM(0.05),
                                       callback = callback,
                                       maxiters = 300)

Let me know if that worked.

Chris’s example worked for me if you change the argument order of x and theta to this:

optf = Optimization.OptimizationFunction((theta, x) -> loss_neuralode(theta), adtype)

or at least it seems to update both p and u0 and reduce the loss.

Cool, updated.

Thanks, yes it works great!
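For reference, with this ComponentArray layout the optimized pieces can be read back out of the result afterwards (a small sketch; result_neuralode.u holds the optimized ComponentArray):

theta_opt = result_neuralode.u   # optimized ComponentArray
p_opt = theta_opt.p              # optimized network parameters
u0_opt = theta_opt.u0            # optimized initial condition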

Hello, I’m looking for a method to optimize the parameters p of an ordinary differential equation (ODE) given an initial state. For example, take the function:
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end

p = Float32[1.5, 1.0, 3.0, 1.0]

At the end of the optimization, I want to have the four optimized parameters p.

In your example, the results are the p parameters of the network, e.g.

the optimized params are the network’s layer weights and biases (a ComponentArray with layer_2 holding a 50×2 weight matrix and a 50-element bias, and layer_3 holding a 2×50 weight matrix and a 2-element bias), and the optimized initial conditions are: Float32[1.9961743, 0.49671444]

Lotka-Volterra parameter estimation is exactly the first tutorial:

https://docs.sciml.ai/SciMLSensitivity/dev/tutorials/parameter_estimation_ode/
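For completeness, the core pattern from that tutorial looks roughly like this (a sketch using the lotka_volterra function defined above; it assumes the same packages as in the earlier example are loaded, and p_guess is just an arbitrary starting point):

u0 = Float32[1.0, 1.0]
tspan = (0.0f0, 10.0f0)
tsteps = range(tspan[1], tspan[2], length = 100)
p_true = Float32[1.5, 1.0, 3.0, 1.0]

# generate data from the true parameters
prob = ODEProblem(lotka_volterra, u0, tspan, p_true)
data = Array(solve(prob, Tsit5(), saveat = tsteps))

# loss of a candidate parameter vector: re-solve the ODE with it and compare to the data
function loss(p)
    sol = solve(remake(prob, p = p), Tsit5(), saveat = tsteps)
    return sum(abs2, data .- Array(sol))
end

p_guess = Float32[1.2, 0.8, 2.5, 0.8]
optf = Optimization.OptimizationFunction((p, _) -> loss(p), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, p_guess)
result = Optimization.solve(optprob, ADAM(0.05), maxiters = 300)
result.u  # the four optimized parameters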
