Minimum Working Example (MWE) showing an error in a Universal Differential Equation (UDE) implementation

The following code is a Minimum Working Example for a UDE that I wrote, but unfortunately it throws an error. When I run the code in VS Code, the terminal crashes.

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationOptimJL, LineSearches
using Statistics
using StableRNGs, Lux, Zygote, Plots, ComponentArrays

rng = StableRNG(11)

# Generating training data
function actualODE!(du, u, p, t, T∞, I)
    Cbat = 5 * 3600               # Battery capacity: 5 Ah expressed in coulombs
    du[1] = -I / Cbat             # State of charge

    C₁ = -0.00153                 # Unit: s⁻¹
    C₂ = 0.020306                 # Unit: K/J

    R0 = 0.03                     # Resistance set at 30 mΩ
    Qgen = (I^2) * R0             # Joule heating

    du[2] = C₁ * (u[2] - T∞) + C₂ * Qgen   # Cell temperature
end

t1 = collect(0:1:3400)
T∞1, I1 = 298.15, 5

actualODE1!(du, u, p, t) = actualODE!(du, u, p, t, T∞1, I1)

prob = ODEProblem(actualODE1!, [1.0, T∞1], (t1[1], t1[end]))
solution = solve(prob, Tsit5(), saveat = t1)
X = Array(solution)
T1 = X[2, :]
# Plotting the results
plot(solution[2, :], color = :red, label = "True Data")


# Defining the neural network
const U = Lux.Chain(Lux.Dense(3, 20, tanh), Lux.Dense(20, 20, tanh), Lux.Dense(20, 1))
_para, st = Lux.setup(rng, U)
const _st = st

function NODE_model!(du, u, p, t, T∞, I)
    Cbat = 5 * 3600
    du[1] = -I / Cbat

    C₁ = -0.00153
    C₂ = 0.020306

    # The network replaces the (I^2)*R0 heat-generation term
    G = I * (U([u[1], u[2], I], p, _st)[1][1])

    du[2] = C₁ * (u[2] - T∞) + C₂ * G
end

NODE_model1!(du, u, p, t) = NODE_model!(du, u, p, t, T∞1, I1)
prob1 = ODEProblem(NODE_model1!, [1.0, T∞1], (t1[1], t1[end]), _para)

function loss(θ)
    _prob1 = remake(prob1, p = θ)
    _sol = Array(solve(_prob1, Tsit5(), saveat = t1))
    loss1 = mean(abs2, T1 .- _sol[2, :])
    return loss1
end

losses = Float64[]

callback = function (state, l)
    push!(losses, l)
    println("RMSE Loss at iteration $(length(losses)) is $(sqrt(l))")
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(_para))

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 500)

Before crashing, a warning about EnzymeVJP is shown; after that a flood of messages scrolls by and the terminal crashes. Because of the crash I couldn't copy the messages, but I took some screenshots, which I am attaching.


Does anybody know why this happens? Is the same issue occurring on your system?

Update on the code: I was able to run it after changing the solver settings. I modified the loss function to the following

function loss(θ)
    _prob1 = remake(prob1, p = θ)
    _sol = Array(solve(_prob1, Tsit5(), saveat = t1, abstol = 1e-6, reltol = 1e-6,
                       sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
    loss1 = mean(abs2, T1 .- _sol[2, :])
    return loss1
end
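
If I understand it correctly, specifying sensealg explicitly skips the automatic sensitivity-algorithm selection, which was apparently defaulting to an Enzyme-based vector-Jacobian product; that would explain why the EnzymeVJP warning (and the crash) went away. That is only my guess, though.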

I tried the BFGS algorithm after Adam, using the following lines of code

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, BFGS(), callback = callback, maxiters = 50)

The loss is decreasing and the parameters are updating, but the solve returns retcode: Failure. Does anyone know why this happens?
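
One guess I still want to test (I have not verified this): maybe it is the default line search that fails. Since LineSearches is already loaded, Optim's BFGS can be given a different one, e.g.

res2 = Optimization.solve(optprob2,
    BFGS(initial_stepnorm = 0.01, linesearch = LineSearches.BackTracking()),
    callback = callback, maxiters = 50)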

Please let me know if you have any ideas about the issues I posted. Any help would be much appreciated.

I have no idea what is happening, but a few comments. First, explain what you’re trying to do. I’m guessing you want a neural network to generate the forcing function for an ODE, but I’m too lazy to step through everything and figure it out.

Second, your MWE isn't all that minimal. It would be easier for others if you simplified further, e.g. get rid of the numerical parameters and use a first-order ODE, one where you know the optimum. Then maybe you can produce by hand the NODE that works, to use as a reference for the optimization. Also, before you use a NODE, why not just try to optimize a constant forcing function, so you're testing Enzyme without an ANN?
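
Here is a rough, untested sketch of what I mean: a scalar ODE with a single constant forcing parameter, where you know the optimum in advance (all names made up by me):

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, Statistics

f!(du, u, p, t) = (du[1] = -u[1] + p[1])   # scalar ODE, constant forcing p[1]

prob0 = ODEProblem(f!, [0.0], (0.0, 10.0), [0.5])
target = Array(solve(remake(prob0, p = [2.0]), Tsit5(), saveat = 0.1))   # optimum is p[1] = 2

function simpleloss(p)
    sol = Array(solve(remake(prob0, p = p), Tsit5(), saveat = 0.1))
    return mean(abs2, target .- sol)
end

optf0 = Optimization.OptimizationFunction((x, _) -> simpleloss(x), Optimization.AutoZygote())
res0 = Optimization.solve(Optimization.OptimizationProblem(optf0, [0.5]),
                          OptimizationOptimisers.Adam(0.1), maxiters = 200)
# res0.u[1] should end up near 2.0

If that already misbehaves, the problem is in the solver/AD setup, not in the network.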

Third, I personally wouldn’t use a callback to log data. The ODE solvers already provide ways to log the info you need, which lets them manage allocations. Here the push! allocates, and it will store every query the solver attempts. Not sure if it’s a problem, but I wouldn’t recommend it.
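
If you do keep the callback, you could at least preallocate and print less often; untested sketch:

losses = sizehint!(Float64[], 500)   # reserve room up front
callback = function (state, l)
    push!(losses, l)
    # print every 50th call instead of every one
    length(losses) % 50 == 0 && println("iter $(length(losses)): RMSE = $(sqrt(l))")
    return false
end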

When the terminal crashes, I suspect you're running out of memory or hitting other system-level problems. The Failure retcode suggests the optimization didn't converge, which is why it would help to make the MWE more minimal.
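
For more detail on why BFGS stopped, the result object should (if I remember the API right) wrap the underlying Optim output:

res2.original   # the raw Optim.jl result, with its detailed convergence flags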

Sorry I can’t say more, just commenting as a naive non-expert.