Dropout in UDEs using SciML and Lux

Dear all,

I am trying to use a Dropout layer in the NN component of a UDE for parameter estimation. It works if I use a setup based on Flux. However, if I switch to Lux, I get an error message that is too long to be displayed fully in VS Code. I based my code on this and this, so with Lux the implementation is as follows:

# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity #, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux #, Zygote, Plots

# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, Adam(), callback = callback, maxiters = 10)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

The code above runs without any problems; the error only appears once a Lux.Dropout(0.3) layer is added to the chain U. If anyone has an idea how to fix this, I'd appreciate it!

I checked the stacktrace. It seems dropout is being called with a vector of ForwardDiff Duals. I am not entirely sure where ForwardDiff enters the picture here. @ChrisRackauckas? (I can add an extension to LuxLib to handle ForwardDiff.)

@nina, some pointers regarding the code: û = U(u, p, st)[1] discards the state st returned by the model. This is fine for models that don't use the state, but once you introduce Dropout you need to be careful with it. If you never update the model's state, dropout effectively becomes a deterministic layer that drops the same set of activations in every iteration.
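A minimal sketch of the difference, assuming a recent Lux.jl where `Lux.Dropout` keeps its RNG in the state `st` and returns an advanced copy of that RNG from every forward call. An explicit `Xoshiro` generator is used (an assumption of the sketch) so that the RNG stored in `st` can be copied deterministically.

```julia
using Lux, Random

# Explicit, copyable RNG so the dropout mask is determined by `st` alone
rng = Xoshiro(0)
layer = Lux.Dropout(0.5)
ps, st = Lux.setup(rng, layer)
x = ones(Float32, 100)

# Reusing the same `st` (as in `U(u, p, st)[1]`) reproduces the same mask
y1, _ = layer(x, ps, st)
y2, _ = layer(x, ps, st)

# Threading the returned state into the next call draws a new mask
y3, st_next = layer(x, ps, st)
y4, _ = layer(x, ps, st_next)
```

Comparing `y1` with `y2` shows the "deterministic dropout" effect, while `y3` and `y4` differ because the second call receives the state returned by the first.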

The system is small enough that the automatic sensealg selection decided forward mode would be faster, so it switched the sensitivity analysis of the ODE solve to forward mode (ForwardDiff).
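One possible workaround until a fix is available (an assumption, not something tested in this thread) is to pass a reverse-mode `sensealg` explicitly, so the automatic selection never falls back to ForwardDiff. The sketch uses an out-of-place right-hand side because `ZygoteVJP` cannot differentiate through in-place `du[i] = ...` assignments; the network, time span and loss are illustrative only.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Lux, ComponentArrays, Zygote, Random

rng = Xoshiro(0)
U = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dropout(0.3), Lux.Dense(5, 2))
p0, st = Lux.setup(rng, U)
θ0 = ComponentVector{Float64}(p0)

# Out-of-place RHS (returns du) so Zygote can compute vector-Jacobian products
rhs(u, θ, t) = first(U(u, θ, st))

prob = ODEProblem(rhs, [1.0, 1.0], (0.0, 1.0), θ0)

function loss(θ)
    # Explicit adjoint: bypasses the automatic choice of ForwardDiffSensitivity
    sol = solve(remake(prob; p = θ), Tsit5(); saveat = 0.1,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return sum(abs2, Array(sol))
end

grad = Zygote.gradient(loss, θ0)[1]
```

For small systems forward mode is usually the faster choice, so this mainly serves to keep Dual numbers away from layers that cannot handle them.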

Ah, OK. In that case let me define a dispatch for ForwardDiff in LuxLib.

Patched it in "Add ForwardDiff Extension: Dropout" by avik-pal (Pull Request #269, avik-pal/Lux.jl on GitHub) and tested the posted code. It should be merged and tagged by tomorrow.

Thank you, avik-pal, also for pointing out that one has to be careful when passing the same st to the NN. So would the way to solve this be to call Lux.setup(rng, U) in every epoch and simply keep the learned parameters instead of the ones returned by the setup function? I guess there are different choices regarding WHEN to resample the dropped nodes. Even though my question focuses on how to code this efficiently, I'd also be interested in heuristics in the context of UDEs, if you have any.
Regarding the initial problem: I'll try out your patch once it is merged!
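The per-epoch idea can be sketched without rebuilding the parameters, assuming `Lux.initialstates` (which constructs only the state) is enough to give the dropout layer a new RNG while the trained `ps` stay untouched. Seeding a fresh `Xoshiro` per epoch is an assumption of this sketch, and the loop body is a placeholder for the actual optimisation step.

```julia
using Lux, Random

model = Lux.Chain(Lux.Dense(2, 64, tanh), Lux.Dropout(0.3), Lux.Dense(64, 2))
ps, _ = Lux.setup(Xoshiro(0), model)   # ps would be the trained parameters
x = ones(Float32, 2)

outputs = Vector{Vector{Float32}}()
for epoch in 1:3
    # Only the state is rebuilt: the dropout layer gets a freshly seeded RNG
    st_epoch = Lux.initialstates(Xoshiro(epoch), model)
    y, _ = model(x, ps, st_epoch)
    push!(outputs, y)
    # ... an optimisation step on `ps` with `st_epoch` held fixed would go here ...
end
```

Each epoch sees a different mask, while `ps` is never reinitialised.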

LuxLib is now tagged so you should be able to get the code working.

Regarding when to resample: a useful heuristic is to resample only after the dynamical system has been evaluated completely. Plain Dropout is therefore likely not a good idea here, because it changes how the dynamical system behaves at every call of ude_dynamics!. Instead I would recommend the layer from Layers - Lux.jl, which is specifically written for these situations. Once it is sampled, you need to ask it to resample manually. I don't use Optimization.jl extensively, but it should be possible to update st from the callback function to do so. My typical workflow looks like DeepEquilibriumNetworks.jl/main.jl at main · SciML/DeepEquilibriumNetworks.jl · GitHub.
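The "resample only after a full evaluation" heuristic can be sketched with the ordinary `Lux.Dropout` and a manually frozen state, rather than the dedicated layer linked above, so this shows the mechanism only. `NN_STATE`, `nn_term` and `resample_callback` are hypothetical names; the callback follows the `(p, l)` form used in the original code, and it advances the state by one dummy forward pass and keeps the returned state.

```julia
using Lux, Random

model = Lux.Chain(Lux.Dense(2, 64, tanh), Lux.Dropout(0.3), Lux.Dense(64, 2))
ps, st0 = Lux.setup(Xoshiro(0), model)

# Hypothetical holder (untyped for simplicity): every RHS call reads it, none
# writes it, so one complete ODE solve sees a single dropout mask
const NN_STATE = Ref{Any}(st0)
nn_term(u, p) = first(model(u, p, NN_STATE[]))

# Hypothetical callback: after a full loss evaluation, run one forward pass
# and store the returned state, whose dropout RNG has advanced
function resample_callback(p, l)
    _, st_new = model(zeros(Float32, 2), p, NN_STATE[])
    NN_STATE[] = st_new
    return false  # never stop the optimisation
end
```

Inside `ude_dynamics!`, `U(u, p, st)[1]` would become `nn_term(u, p)`, and `resample_callback` would be passed as (or called from) the Optimization.jl callback.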